An introduction to the Layers module in Metalhead.jl

Since v0.8, Metalhead.jl exports a Layers module that contains a number of useful layers and utilities for building neural networks. This guide will walk you through the most commonly used layers and utilities present in the Layers module, and how to use them. It also contains some examples of how these layers are used in Metalhead.jl as well as a comprehensive API reference.

Warning

The Layers module is still a work in progress. While we will endeavour to keep the API stable, we cannot guarantee that it will not change in the future. In particular, the API may change significantly between major versions of Metalhead.jl. If you find any of the functions in this module do not work as expected, please open an issue on GitHub.

First, however, you want to make sure that the Layers module is loaded, and that the functions and types are available in your current scope. You can do this by running the following code:

using Metalhead
using Metalhead.Layers

Convolution + Normalisation: the conv_norm layer

One of the most common patterns in modern neural networks is to have a convolutional layer followed by a normalisation layer. Most major deep learning libraries have a way to combine these two layers into a single layer. In Metalhead.jl, this is done with the Metalhead.Layers.conv_norm layer. The function signature for this is given below:

Metalhead.Layers.conv_norm (Function)
conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
          activation = relu; norm_layer = BatchNorm, revnorm::Bool = false,
          preact::Bool = false, stride::Integer = 1, pad::Integer = 0,
          dilation::Integer = 1, groups::Integer = 1, [bias, weight, init])

Create a convolution + normalisation layer pair with activation.

Arguments

  • kernel_size: size of the convolution kernel (tuple)
  • inplanes: number of input feature maps
  • outplanes: number of output feature maps
  • activation: the activation function for the final layer
  • norm_layer: the normalisation layer used. Note that using identity as the normalisation layer will result in no normalisation being applied. (This is only compatible with preact and revnorm both set to false.)
  • revnorm: set to true to place the normalisation layer before the convolution
  • preact: set to true to place the activation function before the normalisation layer (only compatible with revnorm = false)
  • bias: bias for the convolution kernel. This is set to false by default if norm_layer is not identity and true otherwise.
  • stride: stride of the convolution kernel
  • pad: padding of the convolution kernel
  • dilation: dilation of the convolution kernel
  • groups: groups for the convolution kernel
  • weight, init: initialization for the convolution kernel (see Flux.Conv)

For the exact details of each of these parameters, refer to the docstring above. For now, we will focus on some common use cases. For example, if you want to create a 3x3 convolutional layer with 32 input channels and 64 output channels, followed by a BatchNorm layer, you can do the following:

conv_norm((3, 3), 32, 64)

This returns a Vector with the desired layers. To use it in a model, the user should splat it into a Chain. For example:

Chain(conv_norm((3, 3), 3, 32)..., conv_norm((3, 3), 32, 64)..., AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(64, 10))
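
To check the shapes end to end, here is a minimal sketch that binds such a Chain to a variable and runs a dummy batch through it. It assumes Flux is loaded (for Chain, AdaptiveMeanPool, flatten and Dense), and the 32x32 RGB input size is arbitrary:

using Flux
using Metalhead.Layers

model = Chain(conv_norm((3, 3), 3, 32)...,    # 3 => 32 channels
              conv_norm((3, 3), 32, 64)...,   # 32 => 64 channels
              AdaptiveMeanPool((1, 1)),       # global average pooling
              Flux.flatten, Dense(64, 10))    # flatten + classifier head

x = rand(Float32, 32, 32, 3, 1)  # one 32x32 RGB image with a batch dimension
size(model(x))                   # (10, 1)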

The default activation function for conv_norm is relu, and the default normalisation layer is BatchNorm. To use a different activation function, you can just pass it in as a positional argument. For example, to use a sigmoid activation function:

conv_norm((3, 3), 32, 64, sigmoid)

Let's try something else. Suppose you want to use an InstanceNorm layer instead of a BatchNorm layer. Since norm_layer is a keyword argument of conv_norm, and InstanceNorm (like BatchNorm) can be constructed from just the number of channels, you can simply pass in the constructor:

conv_norm((3, 3), 32, 64; norm_layer = InstanceNorm)

What if you want to configure specific parameters of the norm_layer? For example, what if you want to use a GroupNorm layer, whose constructor also needs the number of groups?

# defining the norm layer
norm_layer = planes -> GroupNorm(planes, 4)
# passing it to the conv_norm layer
conv_norm((3, 3), 32, 64; norm_layer = norm_layer)

One of Julia's features is that functions are first-class objects and can be passed around as arguments to other functions. Here, we have created an anonymous function that takes the number of planes as an argument and returns a GroupNorm layer with 4 groups. This is then passed to the norm_layer keyword argument of conv_norm. Using anonymous functions allows us to configure layers in a very flexible manner, and this is a common pattern in Metalhead.jl.

Let's take a slightly more complicated example. TensorFlow uses different defaults for its normalisation layers. In particular, it uses an epsilon value of 1e-3 for BatchNorm layers. If you want to use the same defaults as TensorFlow, you can do the following:

# 1e-3 is a Float64 literal, and Flux is optimised for Float32, so we write it as 1.0f-3
conv_norm((3, 3), 32, 64; norm_layer = planes -> BatchNorm(planes, eps = 1.0f-3))

which, incidentally, is very similar to the code Metalhead uses internally for the Metalhead.Layers.basic_conv_bn layer that is used in the Inception family of models.

Metalhead.Layers.basic_conv_bn (Function)
basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu;
              kwargs...)

Returns a convolution + batch normalisation pair with activation as used by the Inception family of models with default values matching those used in the official TensorFlow implementation.

Arguments

  • kernel_size: size of the convolution kernel (tuple)
  • inplanes: number of input feature maps
  • outplanes: number of output feature maps
  • activation: the activation function for the final layer
  • batchnorm: set to true to include batch normalization after each convolution
  • kwargs: keyword arguments passed to conv_norm
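
For instance, based on the docstring above, a 3x3 convolution + BatchNorm pair going from 32 to 64 channels with these TensorFlow-style defaults can be created as below. Like conv_norm, it returns a Vector of layers that you splat into a Chain:

basic_conv_bn((3, 3), 32, 64)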

Normalisation layers

The Layers module provides some custom normalisation layers that are not present in Flux.

Metalhead.Layers.LayerScale (Function)
LayerScale(planes::Integer, λ)

Creates a Flux.Scale layer that performs "LayerScale" (reference).

Arguments

  • planes: Size of channel dimension in the input.
  • λ: initialisation value for the learnable diagonal matrix.
Metalhead.Layers.LayerNormV2 (Type)
LayerNormV2(size..., λ=identity; affine=true, eps=1f-5)

Same as Flux's LayerNorm, but eps is added before taking the square root in the denominator. Therefore, LayerNormV2 matches PyTorch's LayerNorm.

Metalhead.Layers.ChannelLayerNorm (Type)
ChannelLayerNorm(sz::Integer, λ = identity; eps = 1.0f-6)

A variant of LayerNorm where the input is normalised along the channel dimension. The input is expected to have channel dimension with size sz. It also applies a learnable shift and rescaling after the normalization.

Note that this is specifically for inputs with 4 dimensions in the format (H, W, C, N) where H, W are the height and width of the input, C is the number of channels, and N is the batch size.

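
As a rough sketch of how these layers are constructed and applied (the sizes below are arbitrary):

# ChannelLayerNorm normalises along the channel dimension of a (H, W, C, N) input
cln = ChannelLayerNorm(64)
size(cln(rand(Float32, 16, 16, 64, 2)))  # (16, 16, 64, 2)

# LayerNormV2 is used like Flux's LayerNorm
ln = LayerNormV2(64)
size(ln(rand(Float32, 64, 10)))  # (64, 10)

# LayerScale(planes, λ) builds a Flux.Scale layer initialised to λ
scale = LayerScale(768, 1.0f-6)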

There is also a utility function, prenorm, which applies a normalisation layer before a given block and simply returns a Chain with the normalisation layer and the block. This is useful for creating Vision Transformer (ViT)-style models.

Metalhead.Layers.prenorm (Function)
prenorm(planes, block; norm_layer = LayerNorm)

Utility function to apply a normalization layer before a block.

Arguments

  • planes: Size of dimension to normalize.
  • block: The block before which the normalization layer is applied.
  • norm_layer: The normalization layer to use.
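
For example, assuming Flux is loaded for Dense, the call below is a sketch of wrapping a block (here a plain Dense layer standing in for an attention or MLP block) in a preceding LayerNorm:

# roughly equivalent to Chain(LayerNorm(64), Dense(64, 64))
prenorm(64, Dense(64, 64))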

Dropout layers

The Layers module provides two dropout-like layers not present in Flux:

Metalhead.Layers.DropBlock (Type)
DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, [rng])

The DropBlock layer. During training, it zeroes out contiguous regions of size block_size in the input. It can be used in two ways: either with all blocks having the same survival probability or with a linear scaling rule across the blocks. At test time, the DropBlock layer is equivalent to identity and simply returns the input x.

(reference)

Arguments

  • drop_block_prob: probability of dropping a block. If nothing is passed, it returns identity. Note that some literature uses the term "survival probability" instead, which is equivalent to 1 - drop_block_prob.
  • block_size: size of the block to drop
  • gamma_scale: multiplicative factor for gamma used. For the calculation of gamma, refer to the paper.
  • rng: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU.
Metalhead.Layers.StochasticDepth (Function)
StochasticDepth(p, mode = :row; [rng])

Implements Stochastic Depth. This is a Dropout layer from Flux that drops values with probability p. (reference)

This layer can be used to drop certain blocks in a residual structure and allow them to propagate completely through the skip connection. It can be used in two ways: either with all blocks having the same survival probability or with a linear scaling rule across the blocks. This is performed only at training time. At test time, the StochasticDepth layer is equivalent to identity.

Arguments

  • p: probability of Stochastic Depth. Note that some literature uses the term "survival probability" instead, which is equivalent to 1 - p.
  • mode: Either :batch or :row. :batch randomly zeroes the entire input, while :row zeroes randomly selected rows from the batch. The default is :row.
  • rng: can be used to pass in a custom RNG instead of the default. See Flux.Dropout for more information on the behaviour of this argument. Custom RNGs are only supported on the CPU.
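
As a sketch of how these layers might be dropped into a model (the probabilities and block size below are arbitrary, and SkipConnection and Chain come from Flux):

# zero out 7x7 blocks of the feature maps with probability 0.1 during training
drop_block = DropBlock(0.1, 7)

# drop an entire residual branch with probability 0.2 during training;
# placing StochasticDepth on the residual branch means a dropped block
# reduces to the identity provided by the skip connection
SkipConnection(Chain(conv_norm((3, 3), 64, 64; pad = 1)..., StochasticDepth(0.2)), +)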

DropBlock also has a functional variant present in the Layers module:

Metalhead.Layers.dropblock (Function)
dropblock([rng], x::AbstractArray{T, 4}, drop_block_prob, block_size,
          gamma_scale, active::Bool = true)

The dropblock function. If active is true, for each input, it zeroes out contiguous regions of size block_size in the input. Otherwise, it simply returns the input x.

Arguments

  • rng: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU.
  • x: input array
  • drop_block_prob: probability of dropping a block. If nothing is passed, it returns identity.
  • block_size: size of the block to drop
  • gamma_scale: multiplicative factor for gamma used. For the calculations, refer to the paper.

If you are not a package developer, you most likely do not want this function. Use DropBlock instead.


Both DropBlock and StochasticDepth are used along with probability values that vary based on a linear schedule across the structure of the model (see the respective papers for more details). The Layers module provides a utility function to create such a schedule as well:

Metalhead.Layers.linear_scheduler (Function)
linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth)
linear_scheduler(drop_prob::Nothing; depth::Integer)

Returns the dropout probabilities for a given depth using the linear scaling rule. Note that this returns evenly spaced values between start_value and drop_prob, not including drop_prob. If drop_prob is nothing, it returns a Vector of length depth with all values equal to nothing.

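
For instance, a sketch of generating a schedule for four blocks and building one StochasticDepth layer per block (the target probability is arbitrary):

# four evenly spaced probabilities starting at 0.0, not including 0.2
probs = linear_scheduler(0.2; depth = 4)

# one StochasticDepth layer per block, each with its scheduled probability
sd_layers = [StochasticDepth(p) for p in probs]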

The Metalhead.resnet function, which powers the ResNet family of models in Metalhead.jl, is configured to allow the use of both of these layers. For examples, check out the guide on using the ResNet family in Metalhead. You can also use these layers to construct other custom models.

Pooling layers

The Layers module provides a Metalhead.Layers.AdaptiveMeanMaxPool layer, which is inspired by a similar layer present in timm.

Metalhead.Layers.AdaptiveMeanMaxPool (Function)
AdaptiveMeanMaxPool([connection = +], output_size::Tuple = (1, 1))

A type of adaptive pooling layer which uses both mean and max pooling and combines them to produce a single output. Note that this is equivalent to Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size)). When connection is not specified, it defaults to +.

Arguments

  • connection: The connection type to use.
  • output_size: The size of the output after pooling.
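
A quick sketch of what this looks like in practice (arbitrary input size; Parallel and the adaptive pooling layers come from Flux):

pool = AdaptiveMeanMaxPool((1, 1))
size(pool(rand(Float32, 16, 16, 64, 2)))  # (1, 1, 64, 2)

# conceptually the same as
Parallel(+, AdaptiveMeanPool((1, 1)), AdaptiveMaxPool((1, 1)))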

Many mid-level model functions in Metalhead.jl have been written to support passing custom pooling layers to them where applicable (either in the model itself or in the classifier head). For example, the Metalhead.resnet function supports this, and examples can be found in the guide on using the ResNet family in Metalhead.

Classifier creation

Metalhead provides a function to create a classifier for neural network models that is quite flexible, and is used by the library extensively to create the classifier "head" for networks. This function is called Metalhead.Layers.create_classifier and is documented below:

Metalhead.Layers.create_classifier (Function)
create_classifier(inplanes::Integer, nclasses::Integer, activation = identity;
                  use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), 
                  dropout_prob = nothing)

Creates a classifier head to be used for models.

Arguments

  • inplanes: number of input feature maps
  • nclasses: number of output classes
  • activation: activation function to use
  • use_conv: whether to use a 1x1 convolutional layer instead of a Dense layer.
  • pool_layer: pooling layer to use. This is passed in already instantiated with any arguments it needs, e.g. AdaptiveMeanPool((1, 1)).
  • dropout_prob: dropout probability used in the classifier head. Set to nothing to disable dropout.
create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer,
                  activations::NTuple{2} = (relu, identity);
                  use_conv::NTuple{2, Bool} = (false, false),
                  pool_layer = AdaptiveMeanPool((1, 1)), dropout_prob = nothing)

Creates a classifier head to be used for models with an extra hidden layer.

Arguments

  • inplanes: number of input feature maps
  • hidden_planes: number of hidden feature maps
  • nclasses: number of output classes
  • activations: activation functions to use for the hidden and output layers. This is a tuple of two elements, the first being the activation function for the hidden layer and the second for the output layer.
  • use_conv: whether to use a 1x1 convolutional layer instead of a Dense layer. This is a tuple of two booleans, the first for the hidden layer and the second for the output layer.
  • pool_layer: pooling layer to use. This is passed in already instantiated with any arguments it needs, e.g. AdaptiveMeanPool((1, 1)).
  • dropout_prob: dropout probability used in the classifier head. Set to nothing to disable dropout.

Due to the power of multiple dispatch in Julia, the above function can be called with two different signatures: one creates a classifier with no hidden layer, and the other creates a classifier with a single hidden layer. Both signatures are documented above, and you can choose whichever is more convenient. Both are used in Metalhead.jl: the latter is used in MobileNetv3, and the former almost everywhere else.
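
To make the distinction concrete, here is a sketch of both forms (the sizes are arbitrary, and hardswish is the hard-swish activation from Flux/NNlib):

# no hidden layer: roughly global pool -> flatten -> Dense(512 => 1000)
create_classifier(512, 1000)

# one hidden layer with dropout, in the style of a MobileNetv3 head
create_classifier(960, 1280, 1000, (hardswish, identity); dropout_prob = 0.2)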