Layers

Metalhead also defines a module called Layers, which contains custom layers used to build the models in Metalhead. These layers are not currently available in Flux. To use the functions defined in the Layers module, you need to import it:

using Metalhead: Layers

This page contains the API reference for the Layers module.

Warning

The Layers module is still a work in progress. While we will endeavour to keep the API stable, we cannot guarantee that it will not change in the future. If you find that any of the functions in this module do not work as expected, please open an issue on GitHub.

Convolution + BatchNorm layers

Metalhead.Layers.conv_normFunction
conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
          activation = relu; norm_layer = BatchNorm, revnorm::Bool = false,
          preact::Bool = false, stride::Integer = 1, pad::Integer = 0,
          dilation::Integer = 1, groups::Integer = 1, [bias, weight, init])

Create a convolution + normalisation layer pair with activation.

Arguments

  • kernel_size: size of the convolution kernel (tuple)
  • inplanes: number of input feature maps
  • outplanes: number of output feature maps
  • activation: the activation function for the final layer
  • norm_layer: the normalisation layer used. Note that using identity as the normalisation layer will result in no normalisation being applied. (This is only compatible with preact and revnorm both set to false.)
  • revnorm: set to true to place the normalisation layer before the convolution
  • preact: set to true to place the activation function before the normalisation layer (only compatible with revnorm = false)
  • bias: bias for the convolution kernel. This is set to false by default if norm_layer is not identity and true otherwise.
  • stride: stride of the convolution kernel
  • pad: padding of the convolution kernel
  • dilation: dilation of the convolution kernel
  • groups: groups for the convolution kernel
  • weight, init: initialization for the convolution kernel (see Flux.Conv)
source
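
As an illustration, here is a minimal sketch of how conv_norm might be used to assemble a small feature extractor. It assumes the returned layers can be splatted into a Flux Chain; the exact return type may differ between versions.

using Flux
using Metalhead: Layers

# A 3x3 convolution from 3 => 16 channels with BatchNorm + relu, followed by a
# strided 3x3 convolution from 16 => 32 channels.
backbone = Chain(Layers.conv_norm((3, 3), 3, 16, relu; pad = 1)...,
                 Layers.conv_norm((3, 3), 16, 32, relu; stride = 2, pad = 1)...)

x = rand(Float32, 32, 32, 3, 1)   # (H, W, C, N) input
size(backbone(x))                 # expected: (16, 16, 32, 1)
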
Metalhead.Layers.basic_conv_bnFunction
basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu;
              kwargs...)

Returns a convolution + batch normalisation pair with activation, as used by the Inception family of models, with default values matching those used in the official TensorFlow implementation.

Arguments

  • kernel_size: size of the convolution kernel (tuple)
  • inplanes: number of input feature maps
  • outplanes: number of output feature maps
  • activation: the activation function for the final layer
  • batchnorm: set to true to include batch normalization after each convolution
  • kwargs: keyword arguments passed to conv_norm
source

These blocks are designed to be used in convolutional neural networks. Most of them are used in the MobileNet and EfficientNet families of models, but they also feature in "fancier" versions of well-known models such as ResNet (SE-ResNet).

Metalhead.Layers.dwsep_conv_normFunction
dwsep_conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
                activation = relu; norm_layer = BatchNorm, stride::Integer = 1,
                bias::Bool = !(norm_layer !== identity), pad::Integer = 0, [weight, init])

Create a depthwise separable convolution chain as used in MobileNetv1. This is a sequence of layers:

  • a kernel_size depthwise convolution from inplanes => inplanes
  • a (batch) normalisation layer + activation (if norm_layer !== identity; otherwise activation is applied to the convolution output)
  • a kernel_size convolution from inplanes => outplanes
  • a (batch) normalisation layer + activation (if norm_layer !== identity; otherwise activation is applied to the convolution output)

See Fig. 3 in reference.

Arguments

  • kernel_size: size of the convolution kernel (tuple)
  • inplanes: number of input feature maps
  • outplanes: number of output feature maps
  • activation: the activation function for the final layer
  • norm_layer: the normalisation layer used. Note that using identity as the normalisation layer will result in no normalisation being applied.
  • bias: whether to use bias in the convolution layers.
  • stride: stride of the first convolution kernel
  • pad: padding of the first convolution kernel
  • weight, init: initialization for the convolution kernel (see Flux.Conv)
source
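
As a sketch (again assuming the returned layers can be splatted into a Chain), a MobileNetv1-style depthwise separable block could be assembled like this:

using Flux
using Metalhead: Layers

# A 3x3 depthwise convolution over 32 channels with stride 2, followed by a
# pointwise convolution expanding to 64 channels, each with BatchNorm + relu.
block = Chain(Layers.dwsep_conv_norm((3, 3), 32, 64, relu; stride = 2, pad = 1)...)

x = rand(Float32, 56, 56, 32, 1)
size(block(x))    # expected: (28, 28, 64, 1)
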
Metalhead.Layers.mbconvFunction
mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer,
       outplanes::Integer, activation = relu; stride::Integer,
       reduction::Union{Nothing, Real} = nothing,
       se_round_fn = x -> round(Int, x), norm_layer = BatchNorm, kwargs...)

Create a basic inverted residual block for MobileNet and EfficientNet variants. This is a sequence of layers:

  • a 1x1 convolution from inplanes => explanes followed by a (batch) normalisation layer + activation if inplanes != explanes
  • a kernel_size depthwise convolution from explanes => explanes
  • a (batch) normalisation layer
  • a squeeze-and-excitation block (if reduction != nothing) from explanes => se_round_fn(explanes / reduction) and back to explanes
  • a 1x1 convolution from explanes => outplanes
  • a (batch) normalisation layer + activation

Warning

This function does not handle the residual connection by default. The user must add this manually to use this block as a standalone. To construct a model, check out the builders, which handle the residual connection and other details.

First introduced in the MobileNetv2 paper. (See Fig. 3 in reference.)

Arguments

  • kernel_size: kernel size of the convolutional layers
  • inplanes: number of input feature maps
  • explanes: The number of expanded feature maps. This is the number of feature maps after the first 1x1 convolution.
  • outplanes: The number of output feature maps
  • activation: The activation function for the first two convolution layers
  • stride: The stride of the convolutional kernel; must be either 1 or 2
  • reduction: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see squeeze_excite)
  • se_round_fn: The function to round the number of reduced feature maps in the squeeze and excite layer
  • norm_layer: The normalization layer to use
source
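
Because mbconv does not add the skip connection itself, a standalone residual version can be sketched with Flux's SkipConnection. The channel sizes and reduction factor below are purely illustrative:

using Flux
using Metalhead: Layers

# An MBConv block expanding 16 => 96 channels internally and projecting back to
# 16 channels, with a squeeze-and-excitation block (reduction = 4) and stride 1.
block = Layers.mbconv((3, 3), 16, 96, 16, swish; stride = 1, reduction = 4)

# The residual connection is added manually; this only makes sense for stride 1
# blocks whose input and output channel counts match.
residual_block = SkipConnection(block, +)

x = rand(Float32, 32, 32, 16, 1)
size(residual_block(x))    # expected: (32, 32, 16, 1)
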
Metalhead.Layers.fused_mbconvFunction
fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer,
             outplanes::Integer, activation = relu;
             stride::Integer, norm_layer = BatchNorm)

Create a fused inverted residual block.

This is a sequence of layers:

  • a kernel_size depthwise separable convolution from explanes => explanes
  • a (batch) normalisation layer
  • a 1x1 convolution from explanes => outplanes followed by a (batch) normalisation layer + activation if inplanes != explanes

Warning

This function does not handle the residual connection by default. The user must add this manually to use this block as a standalone. To construct a model, check out the builders, which handle the residual connection and other details.

Originally introduced by Google in EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML. Later used in the EfficientNetv2 paper.

Arguments

  • kernel_size: kernel size of the convolutional layers
  • inplanes: number of input feature maps
  • explanes: The number of expanded feature maps
  • outplanes: The number of output feature maps
  • activation: The activation function for the first two convolution layers
  • stride: The stride of the convolutional kernel; must be either 1 or 2
  • norm_layer: The normalization layer to use
source
Metalhead.Layers.squeeze_exciteFunction
squeeze_excite(inplanes::Integer; reduction::Real = 16, round_fn = _round_channels, 
               norm_layer = identity, activation = relu, gate_activation = sigmoid)

Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE-ResNets.

Arguments

  • inplanes: The number of input feature maps
  • reduction: The reduction factor for the number of hidden feature maps in the squeeze and excite layer. The number of hidden feature maps is calculated as round_fn(inplanes / reduction).
  • round_fn: The function to round the number of reduced feature maps.
  • activation: The activation function for the first convolution layer
  • gate_activation: The activation function for the gate layer
  • norm_layer: The normalization layer to be used after the convolution layers
  • rd_planes: The number of hidden feature maps in a squeeze and excite layer
source
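
A minimal sketch of squeeze_excite used on its own; the returned value is assumed to be a callable layer that preserves the input shape:

using Flux
using Metalhead: Layers

# A squeeze-and-excitation block over 64 channels with the default reduction
# factor of 16.
se = Layers.squeeze_excite(64)

x = rand(Float32, 28, 28, 64, 1)
size(se(x))    # expected: (28, 28, 64, 1); the channels are rescaled, not reshaped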

Normalisation, Dropout and Pooling layers

Metalhead provides various custom layers for normalisation, dropout and pooling, which are used to further customise various models.

Normalisation layers

Metalhead.Layers.ChannelLayerNormType
ChannelLayerNorm(sz::Integer, λ = identity; eps = 1.0f-6)

A variant of LayerNorm where the input is normalised along the channel dimension. The input is expected to have a channel dimension of size sz. It also applies a learnable shift and rescaling after the normalisation.

Note that this is specifically for inputs with 4 dimensions in the format (H, W, C, N) where H, W are the height and width of the input, C is the number of channels, and N is the batch size.

source
Metalhead.Layers.LayerNormV2Type
LayerNormV2(size..., λ=identity; affine=true, eps=1f-5)

Same as Flux's LayerNorm, but eps is added before taking the square root in the denominator, so LayerNormV2 matches PyTorch's LayerNorm.

source
Metalhead.Layers.LayerScaleFunction
LayerScale(planes::Integer, λ)

Creates a Flux.Scale layer that performs "LayerScale" (reference).

Arguments

  • planes: Size of channel dimension in the input.
  • λ: initialisation value for the learnable diagonal matrix.
source
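
For example (a sketch; the shapes below are illustrative, and the LayerScale usage assumes Flux.Scale's usual first-dimension broadcasting):

using Flux
using Metalhead: Layers

# ChannelLayerNorm normalises a (H, W, C, N) input along its channel dimension,
# here C = 64.
cln = Layers.ChannelLayerNorm(64)
x = rand(Float32, 14, 14, 64, 2)
size(cln(x))        # expected: (14, 14, 64, 2)

# LayerScale creates a Flux.Scale layer, which broadcasts its length-planes scale
# vector over the first dimension, so it is typically applied where the
# channel/embedding dimension comes first, e.g. a (planes, ntokens, batch) array.
ls = Layers.LayerScale(64, 1.0f-4)
tokens = rand(Float32, 64, 196, 2)
size(ls(tokens))    # expected: (64, 196, 2)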

Dropout layers

Metalhead.Layers.DropBlockType
DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, [rng])

The DropBlock layer. During training, it zeroes out contiguous regions of size block_size in the input; at test time, it is equivalent to identity and simply returns the input x. It can be used in two ways: either with all blocks having the same survival probability or with a linear scaling rule across the blocks.

(reference)

Arguments

  • drop_block_prob: probability of dropping a block. If nothing is passed, it returns identity. Note that some literature uses the term "survival probability" instead, which is equivalent to 1 - drop_block_prob.
  • block_size: size of the block to drop
  • gamma_scale: multiplicative factor for gamma used. For the calculation of gamma, refer to the paper.
  • rng: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU.
source
Metalhead.Layers.dropblockFunction
dropblock([rng], x::AbstractArray{T, 4}, drop_block_prob, block_size,
          gamma_scale, active::Bool = true)

The dropblock function. If active is true, for each input, it zeroes out contiguous regions of size block_size in the input. Otherwise, it simply returns the input x.

Arguments

  • rng: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU.
  • x: input array
  • drop_block_prob: probability of dropping a block. If nothing is passed, it returns identity.
  • block_size: size of the block to drop
  • gamma_scale: multiplicative factor for gamma used. For the calculations, refer to the paper.

If you are not a package developer, you most likely do not want this function. Use DropBlock instead.

source
Metalhead.Layers.StochasticDepthFunction
StochasticDepth(p, mode = :row; [rng])

Implements Stochastic Depth. This is a Dropout layer from Flux that drops values with probability p. (reference)

This layer can be used to drop certain blocks in a residual structure and allow them to propagate completely through the skip connection. It can be used in two ways: either with all blocks having the same survival probability or with a linear scaling rule across the blocks. This is performed only at training time. At test time, the StochasticDepth layer is equivalent to identity.

Arguments

  • p: probability of Stochastic Depth. Note that some literature uses the term "survival probability" instead, which is equivalent to 1 - p.
  • mode: Either :batch or :row. :batch randomly zeroes the entire input, while :row zeroes randomly selected rows of the batch. The default is :row.
  • rng: can be used to pass in a custom RNG instead of the default. See Flux.Dropout for more information on the behaviour of this argument. Custom RNGs are only supported on the CPU.
source
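
A sketch of both regularisation layers inside a residual block (at test time both act as identity, so only the training behaviour differs):

using Flux
using Metalhead: Layers

# A residual branch regularised with DropBlock, plus StochasticDepth to randomly
# drop the whole branch output during training (per row of the batch, by default).
branch = Chain(Conv((3, 3), 32 => 32, relu; pad = 1),
               Layers.DropBlock(0.1, 5),        # drop 5x5 blocks with probability 0.1
               Layers.StochasticDepth(0.2))     # drop the branch with probability 0.2
block = SkipConnection(branch, +)

x = rand(Float32, 16, 16, 32, 4)
size(block(x))    # expected: (16, 16, 32, 4)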

Pooling layers

Metalhead.Layers.AdaptiveMeanMaxPoolFunction
AdaptiveMeanMaxPool([connection = +], output_size::Tuple = (1, 1))

A type of adaptive pooling layer which uses both mean and max pooling and combines them to produce a single output. Note that this is equivalent to Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)). When connection is not specified, it defaults to +.

Arguments

  • connection: The connection type to use.
  • output_size: The size of the output after pooling.
source
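
For instance (a sketch):

using Flux
using Metalhead: Layers

# Combine adaptive mean and max pooling into a single (1, 1) spatial output;
# the two pooled results are summed by default.
pool = Layers.AdaptiveMeanMaxPool()

x = rand(Float32, 7, 7, 512, 1)
size(pool(x))    # expected: (1, 1, 512, 1)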

Classifier creation

Metalhead provides a function to create a classifier for neural network models that is quite flexible, and is used by the library extensively to create the classifier "head" for networks.

Metalhead.Layers.create_classifierFunction
create_classifier(inplanes::Integer, nclasses::Integer, activation = identity;
                  use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), 
                  dropout_prob = nothing)

Creates a classifier head to be used for models.

Arguments

  • inplanes: number of input feature maps
  • nclasses: number of output classes
  • activation: activation function to use
  • use_conv: whether to use a 1x1 convolutional layer instead of a Dense layer.
  • pool_layer: pooling layer to use. This should be passed in already instantiated with any required arguments, e.g. AdaptiveMeanPool((1, 1)).
  • dropout_prob: dropout probability used in the classifier head. Set to nothing to disable dropout.
source
create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer,
                  activations::NTuple{2} = (relu, identity);
                  use_conv::NTuple{2, Bool} = (false, false),
                  pool_layer = AdaptiveMeanPool((1, 1)), dropout_prob = nothing)

Creates a classifier head to be used for models with an extra hidden layer.

Arguments

  • inplanes: number of input feature maps
  • hidden_planes: number of hidden feature maps
  • nclasses: number of output classes
  • activations: activation functions to use for the hidden and output layers. This is a tuple of two elements, the first being the activation function for the hidden layer and the second for the output layer.
  • use_conv: whether to use a 1x1 convolutional layer instead of a Dense layer. This is a tuple of two booleans, the first for the hidden layer and the second for the output layer.
  • pool_layer: pooling layer to use. This should be passed in already instantiated with any required arguments, e.g. AdaptiveMeanPool((1, 1)).
  • dropout_prob: dropout probability used in the classifier head. Set to nothing to disable dropout.
source
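
A sketch of both methods in use; the feature map sizes and class counts are illustrative only:

using Flux
using Metalhead: Layers

# A classifier head mapping 512 feature maps to 10 classes, with dropout.
head = Layers.create_classifier(512, 10; dropout_prob = 0.2)

features = rand(Float32, 7, 7, 512, 4)    # backbone output in (H, W, C, N) layout
size(head(features))                      # expected: (10, 4)

# Variant with an extra hidden layer of 256 feature maps.
head2 = Layers.create_classifier(512, 256, 10)
size(head2(features))                     # expected: (10, 4)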

The Layers module also contains layers used to build vision transformer (ViT)-inspired models:

Metalhead.Layers.MultiHeadSelfAttentionType
MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
                       attn_dropout_prob = 0., proj_dropout_prob = 0.)

Multi-head self-attention layer.

Arguments

  • planes: number of input channels
  • nheads: number of heads
  • qkv_bias: whether to use bias in the layer to get the query, key and value
  • attn_dropout_prob: dropout probability after the self-attention layer
  • proj_dropout_prob: dropout probability after the projection layer
source
Metalhead.Layers.ClassTokensType
ClassTokens(planes::Integer; init = Flux.zeros32)

Appends class tokens to an input with embedding dimension planes for use in many vision transformer models.

source
Metalhead.Layers.ViPosEmbeddingType
ViPosEmbedding(embedsize::Integer, npatches::Integer; 
               init = (dims::Dims{2}) -> rand(Float32, dims))

Positional embedding layer used by many vision transformer-like models.

source
Metalhead.Layers.PatchEmbeddingFunction
PatchEmbedding(imsize::Dims{2} = (224, 224); inchannels::Integer = 3,
               patch_size::Dims{2} = (16, 16), embedplanes = 768,
               norm_layer = planes -> identity, flatten = true)

Patch embedding layer used by many vision transformer-like models to split the input image into patches.

Arguments

  • imsize: the size of the input image
  • inchannels: number of input channels
  • patch_size: the size of the patches
  • embedplanes: the number of channels in the embedding
  • norm_layer: the normalization layer. By default this is the identity function; otherwise, pass a single-argument constructor for a normalization layer such as LayerNorm or BatchNorm
  • flatten: set to true to flatten the spatial dimensions of the input after the embedding
source
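
As a sketch, the patch embedding, class token and positional embedding layers compose in the usual ViT order. The numbers below correspond to the common ViT-B/16 configuration and are used purely for illustration:

using Flux
using Metalhead: Layers

# Embed a 224x224 RGB image as 16x16 patches with 768 embedding planes, prepend a
# class token and add learnable positional embeddings.
npatches = (224 ÷ 16) * (224 ÷ 16)    # 196 patches
embedding = Chain(Layers.PatchEmbedding((224, 224); inchannels = 3,
                                        patch_size = (16, 16), embedplanes = 768),
                  Layers.ClassTokens(768),
                  Layers.ViPosEmbedding(768, npatches + 1))

x = rand(Float32, 224, 224, 3, 2)
size(embedding(x))    # expected: (768, 197, 2), i.e. (embedplanes, npatches + 1, batch)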

Apart from this, the Layers module also contains certain blocks used in MLPMixer-style models:

Metalhead.Layers.gated_mlp_blockFunction
gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
                outplanes::Integer = inplanes; dropout_prob = 0.0, activation = gelu)

Feedforward block based on the implementation in the paper "Pay Attention to MLPs". (reference)

Arguments

  • gate_layer: Layer to use for the gating.
  • inplanes: Number of dimensions in the input.
  • hidden_planes: Number of dimensions in the intermediate layer.
  • outplanes: Number of dimensions in the output - by default it is the same as inplanes.
  • dropout_prob: Dropout probability.
  • activation: Activation function to use.
source
Metalhead.Layers.mlp_blockFunction
mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; 
          dropout_prob = 0., activation = gelu)

Feedforward block used in many MLPMixer-like and vision-transformer models.

Arguments

  • inplanes: Number of dimensions in the input.
  • hidden_planes: Number of dimensions in the intermediate layer.
  • outplanes: Number of dimensions in the output - by default it is the same as inplanes.
  • dropout_prob: Dropout probability.
  • activation: Activation function to use.
source
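
A sketch of mlp_block applied to token embeddings; assuming it is built from Dense layers (as is typical for such blocks), the feature dimension is expected to come first:

using Flux
using Metalhead: Layers

# A feedforward block expanding 256 => 1024 => 256 features with gelu activation.
ffn = Layers.mlp_block(256, 1024)

tokens = rand(Float32, 256, 197, 2)    # (features, tokens, batch)
size(ffn(tokens))                      # expected: (256, 197, 2)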

Utilities for layers

These are some miscellaneous utilities present in the Layers module; they are used with other custom and built-in layers to make certain common operations in neural networks easier.

Metalhead.Layers.inputscaleFunction
inputscale(λ; activation = identity)

Scales the input by a scalar λ and applies an activation function to the result. Equivalent to activation.(λ .* x).

source
Metalhead.Layers.actaddFunction
actadd(activation = relu, xs...)

Convenience function for adding input arrays after applying an activation function to them. Useful as the connection argument for the block function in resnet.

source
Metalhead.Layers.addactFunction
addact(activation = relu, xs...)

Convenience function for applying an activation function to the output after summing up the input arrays. Useful as the connection argument for the block function in resnet.

source
Metalhead.Layers.cat_channelsFunction
cat_channels(x, y, zs...)

Concatenate x and y (and any zs) along the channel dimension (third dimension). Equivalent to cat(x, y, zs...; dims=3). Convenient reduction operator for use with Parallel.

source
Metalhead.Layers.flatten_chainsFunction
flatten_chains(m::Chain)
flatten_chains(m)

Convenience function for traversing the nested layers of a Chain object and flattening them into a single iterator.

source
Metalhead.Layers.linear_schedulerFunction
linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth)
linear_scheduler(drop_prob::Nothing; depth::Integer)

Returns the dropout probabilities for a given depth using the linear scaling rule. Note that this returns evenly spaced values between start_value and drop_prob, not including drop_prob. If drop_prob is nothing, it returns a Vector of length depth with all values equal to nothing.

source
Metalhead.Layers.swapdimsFunction
swapdims(perm)

Convenience function that returns a closure which permutes the dimensions of an array. perm is a vector or tuple specifying a permutation of the input dimensions. The returned closure is equivalent to x -> permutedims(x, perm).

source
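
For example (a sketch), cat_channels and addact slot in as the connection of Parallel or SkipConnection layers, and swapdims returns a closure usable inside a Chain:

using Flux
using Metalhead: Layers

x = rand(Float32, 8, 8, 16, 1)

# cat_channels as a DenseNet-style reduction: concatenate the branch outputs
# along the channel dimension (16 + 8 = 24 channels).
branches = Parallel(Layers.cat_channels,
                    Conv((3, 3), 16 => 16; pad = 1),
                    Conv((1, 1), 16 => 8))
size(branches(x))               # expected: (8, 8, 24, 1)

# addact as the connection of a residual block: relu is applied after the sum.
block = SkipConnection(Conv((3, 3), 16 => 16; pad = 1),
                       (fx, x) -> Layers.addact(relu, fx, x))
size(block(x))                  # expected: (8, 8, 16, 1)

# swapdims: move channels to the front, (H, W, C, N) -> (C, H, W, N).
to_channels_first = Layers.swapdims((3, 1, 2, 4))
size(to_channels_first(x))      # expected: (16, 8, 8, 1)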