More advanced layers

This page contains the API reference for some of the more advanced layers in the Layers module. These layers are used within Metalhead.jl to build more complex models, and can also be used directly to build custom models. For a more basic introduction, please refer to the Layers module introduction guide.

Squeeze-and-excitation blocks

These are used in models like SE-ResNet and SE-ResNeXt, as well as in the design of inverted residual blocks used in the MobileNet and EfficientNet family of models.

Metalhead.Layers.squeeze_excite - Function
squeeze_excite(inplanes::Integer; reduction::Real = 16, round_fn = _round_channels, 
               norm_layer = identity, activation = relu, gate_activation = sigmoid)

Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE-ResNets.

Arguments

  • inplanes: The number of input feature maps
  • reduction: The reduction factor for the number of hidden feature maps in the squeeze and excite layer. The number of hidden feature maps is calculated as round_fn(inplanes / reduction).
  • round_fn: The function to round the number of reduced feature maps.
  • activation: The activation function for the first convolution layer
  • gate_activation: The activation function for the gate layer
  • norm_layer: The normalization layer to be used after the convolution layers
  • rd_planes: The number of hidden feature maps in a squeeze and excite layer
source
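
For illustration, here is a minimal usage sketch. It assumes the Layers submodule is accessed as Metalhead.Layers and that the returned layer preserves the WHCN shape of its input, since it only rescales channels:

    using Flux
    using Metalhead: Layers

    # Squeeze-and-excitation layer for 64 input channels with a reduction factor of 16.
    se = Layers.squeeze_excite(64; reduction = 16)

    # Dummy feature map in WHCN order.
    x = rand(Float32, 28, 28, 64, 1)
    y = se(x)   # expected to have the same size as x, with channels rescaled by the gates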

Inverted residual blocks

These blocks are designed to be used in the MobileNet and EfficientNet family of convolutional neural networks.

Metalhead.Layers.dwsep_conv_norm - Function
dwsep_conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
                activation = relu; norm_layer = BatchNorm, stride::Integer = 1,
                bias::Bool = !(norm_layer !== identity), pad::Integer = 0, [bias, weight, init])

Create a depthwise separable convolution chain as used in MobileNetv1. This is a sequence of layers:

  • a kernel_size depthwise convolution from inplanes => inplanes
  • a (batch) normalisation layer + activation (if norm_layer !== identity; otherwise activation is applied to the convolution output)
  • a 1x1 pointwise convolution from inplanes => outplanes
  • a (batch) normalisation layer + activation (if norm_layer !== identity; otherwise activation is applied to the convolution output)

See Fig. 3 in reference.

Arguments

  • kernel_size: size of the convolution kernel (tuple)
  • inplanes: number of input feature maps
  • outplanes: number of output feature maps
  • activation: the activation function for the final layer
  • norm_layer: the normalisation layer used. Note that using identity as the normalisation layer will result in no normalisation being applied.
  • bias: whether to use bias in the convolution layers.
  • stride: stride of the first convolution kernel
  • pad: padding of the first convolution kernel
  • weight, init: initialization for the convolution kernel (see Flux.Conv)
source
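
A short usage sketch, assuming (as with the other conv_norm-style helpers) that the function returns a vector of layers meant to be splatted into a Chain:

    using Flux
    using Metalhead: Layers

    # Depthwise separable 3x3 block taking 3 channels to 32 channels.
    layers = Layers.dwsep_conv_norm((3, 3), 3, 32; stride = 1, pad = 1)
    block = Chain(layers...)

    x = rand(Float32, 32, 32, 3, 1)
    y = block(x)   # expected size: (32, 32, 32, 1) with stride 1 and pad 1
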
Metalhead.Layers.mbconv - Function
mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer,
       outplanes::Integer, activation = relu; stride::Integer,
       reduction::Union{Nothing, Real} = nothing,
       se_round_fn = x -> round(Int, x), norm_layer = BatchNorm, kwargs...)

Create a basic inverted residual block for MobileNet and EfficientNet variants. This is a sequence of layers:

  • a 1x1 expansion convolution from inplanes => explanes followed by a (batch) normalisation layer + activation (this expansion step is only present if inplanes != explanes)

  • a kernel_size depthwise convolution from explanes => explanes

  • a (batch) normalisation layer

  • a squeeze-and-excitation block (if reduction != nothing) from explanes => se_round_fn(explanes / reduction) and back to explanes

  • a 1x1 projection convolution from explanes => outplanes followed by a (batch) normalisation layer (with no activation, i.e. a linear bottleneck)

Warning

This function does not handle the residual connection by default. The user must add this manually to use this block as a standalone. To construct a model, check out the builders, which handle the residual connection and other details.

First introduced in the MobileNetv2 paper. (See Fig. 3 in reference.)

Arguments

  • kernel_size: kernel size of the convolutional layers
  • inplanes: number of input feature maps
  • explanes: The number of expanded feature maps. This is the number of feature maps after the first 1x1 convolution.
  • outplanes: The number of output feature maps
  • activation: The activation function for the first two convolution layers
  • stride: The stride of the convolutional kernel; must be either 1 or 2
  • reduction: The reduction factor for the number of hidden feature maps in a squeeze and excite layer (see squeeze_excite)
  • se_round_fn: The function to round the number of reduced feature maps in the squeeze and excite layer
  • norm_layer: The normalization layer to use
source
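
A hedged sketch of building and applying such a block on its own. Note that stride is a required keyword and, as the warning above says, no residual connection is added; the channel counts here are purely illustrative:

    using Flux
    using Metalhead: Layers

    # Expand 16 channels to 64, apply a 3x3 depthwise convolution, add a
    # squeeze-and-excitation sub-block (reduction = 4), then project down to 24 channels.
    block = Layers.mbconv((3, 3), 16, 64, 24, swish; stride = 1, reduction = 4)

    x = rand(Float32, 56, 56, 16, 1)
    y = block(x)   # expected size: (56, 56, 24, 1), assuming same-padding in the depthwise convolution
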
Metalhead.Layers.fused_mbconv - Function
fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer,
             outplanes::Integer, activation = relu;
             stride::Integer, norm_layer = BatchNorm)

Create a fused inverted residual block.

This is a sequence of layers:

  • a kernel_size convolution from inplanes => explanes (a single "fused" convolution that replaces the separate expansion and depthwise convolutions of mbconv)
  • a (batch) normalisation layer
  • a 1x1 convolution from explanes => outplanes followed by a (batch) normalisation layer + activation if inplanes != explanes
Warning

This function does not handle the residual connection by default. The user must add this manually to use this block as a standalone. To construct a model, check out the builders, which handle the residual connection and other details.

Originally introduced by Google in EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML. Later used in the EfficientNetv2 paper.

Arguments

  • kernel_size: kernel size of the convolutional layers
  • inplanes: number of input feature maps
  • explanes: The number of expanded feature maps
  • outplanes: The number of output feature maps
  • activation: The activation function for the first two convolution layers
  • stride: The stride of the convolutional kernel; must be either 1 or 2
  • norm_layer: The normalization layer to use
source
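
A similar hedged sketch for the fused variant (again, the residual connection must be added separately):

    using Flux
    using Metalhead: Layers

    # Fused 3x3 convolution expands 24 channels to 96, then a 1x1 convolution
    # projects back down to 24 output channels.
    block = Layers.fused_mbconv((3, 3), 24, 96, 24, swish; stride = 1)

    x = rand(Float32, 56, 56, 24, 1)
    y = block(x)   # expected size: (56, 56, 24, 1) for stride 1 with same-padding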

Vision transformer-related layers

The Layers module contains specific layers that are used to build vision transformer (ViT)-inspired models:

Metalhead.Layers.MultiHeadSelfAttention - Type
MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, 
            attn_dropout_prob = 0., proj_dropout_prob = 0.)

Multi-head self-attention layer.

Arguments

  • planes: number of input channels
  • nheads: number of heads
  • qkv_bias: whether to use bias in the layer to get the query, key and value
  • attn_dropout_prob: dropout probability after the self-attention layer
  • proj_dropout_prob: dropout probability after the projection layer
source
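
A minimal sketch, assuming the layer operates on arrays laid out as (embedding dimension, sequence length, batch size), as in the vision transformer models in Metalhead:

    using Flux
    using Metalhead: Layers

    # 8-head self-attention over a 64-dimensional embedding.
    mhsa = Layers.MultiHeadSelfAttention(64, 8)

    # 197 tokens (e.g. 196 patches + 1 class token) for a batch of 2.
    x = rand(Float32, 64, 197, 2)
    y = mhsa(x)   # expected to have the same size as x
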
Metalhead.Layers.ClassTokens - Type
ClassTokens(planes::Integer; init = Flux.zeros32)

Appends class tokens to an input with embedding dimension planes for use in many vision transformer models.

source
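
A minimal sketch, assuming the layer adds one learnable class token per sample along the patch (second) dimension:

    using Flux
    using Metalhead: Layers

    tokens = Layers.ClassTokens(64)

    x = rand(Float32, 64, 196, 2)   # (embedding, patches, batch)
    y = tokens(x)                   # expected size: (64, 197, 2)
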
Metalhead.Layers.ViPosEmbedding - Type
ViPosEmbedding(embedsize::Integer, npatches::Integer; 
               init = (dims::Dims{2}) -> rand(Float32, dims))

Positional embedding layer used by many vision transformer-like models.

source
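
A minimal sketch, assuming the layer adds a learnable positional embedding to an input laid out as (embedding, patches, batch):

    using Flux
    using Metalhead: Layers

    posembed = Layers.ViPosEmbedding(64, 197)

    x = rand(Float32, 64, 197, 2)
    y = posembed(x)   # expected to have the same size as x
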
Metalhead.Layers.PatchEmbedding - Function
PatchEmbedding(imsize::Dims{2} = (224, 224); inchannels::Integer = 3,
               patch_size::Dims{2} = (16, 16), embedplanes = 768,
               norm_layer = planes -> identity, flatten = true)

Patch embedding layer used by many vision transformer-like models to split the input image into patches.

Arguments

  • imsize: the size of the input image
  • inchannels: number of input channels
  • patch_size: the size of the patches
  • embedplanes: the number of channels in the embedding
  • norm_layer: the normalization layer. By default this is the identity function; otherwise, pass a single-argument constructor for a normalization layer such as LayerNorm or BatchNorm.
  • flatten: set true to flatten the input spatial dimensions after the embedding
source
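
A minimal sketch with the default settings, where a 224x224 image split into 16x16 patches should yield 196 patch embeddings:

    using Flux
    using Metalhead: Layers

    patchify = Layers.PatchEmbedding((224, 224); inchannels = 3,
                                     patch_size = (16, 16), embedplanes = 768)

    x = rand(Float32, 224, 224, 3, 1)   # WHCN image batch
    y = patchify(x)                     # expected size: (768, 196, 1) with flatten = true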

MLPMixer-related blocks

Apart from this, the Layers module also contains certain blocks used in MLPMixer-style models:

Metalhead.Layers.gated_mlp_block - Function
gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
                outplanes::Integer = inplanes; dropout_prob = 0.0, activation = gelu)

Feedforward block based on the implementation in the paper "Pay Attention to MLPs". (reference)

Arguments

  • gate_layer: Layer to use for the gating.
  • inplanes: Number of dimensions in the input.
  • hidden_planes: Number of dimensions in the intermediate layer.
  • outplanes: Number of dimensions in the output - by default it is the same as inplanes.
  • dropout_prob: Dropout probability.
  • activation: Activation function to use.
source
Metalhead.Layers.mlp_block - Function
mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; 
          dropout_prob = 0., activation = gelu)

Feedforward block used in many MLPMixer-like and vision-transformer models.

Arguments

  • inplanes: Number of dimensions in the input.
  • hidden_planes: Number of dimensions in the intermediate layer.
  • outplanes: Number of dimensions in the output - by default it is the same as inplanes.
  • dropout_prob: Dropout probability.
  • activation: Activation function to use.
source
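
A minimal sketch for mlp_block, assuming it returns a Chain of two Dense layers (plus dropout) acting on the first dimension of its input:

    using Flux
    using Metalhead: Layers

    # Feedforward block: 64 -> 256 -> 64 with GELU in between.
    mlp = Layers.mlp_block(64, 256)

    x = rand(Float32, 64, 196, 2)   # (embedding, patches, batch)
    y = mlp(x)                      # expected to have the same size as x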

Miscellaneous utilities for layers

These are some miscellaneous utilities present in the Layers module, and are used with other custom/inbuilt layers to make certain common operations in neural networks easier.

Metalhead.Layers.inputscale - Function
inputscale(λ; activation = identity)

Scales the input by a scalar λ and applies an activation function to it. Equivalent to activation.(λ .* x).

source
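
A minimal sketch based on the description above:

    using Flux
    using Metalhead: Layers

    scale = Layers.inputscale(0.5f0; activation = relu)

    x = randn(Float32, 4)
    scale(x)   # expected to equal relu.(0.5f0 .* x)
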
Metalhead.Layers.actadd - Function
actadd(activation = relu, xs...)

Convenience function for summing up the input arrays after applying an activation function to them. Useful as the connection argument for the block function in Metalhead.resnet.

source
Metalhead.Layers.addact - Function
addact(activation = relu, xs...)

Convenience function for applying an activation function to the output after summing up the input arrays. Useful as the connection argument for the block function in Metalhead.resnet.

source
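
A short sketch contrasting the two helpers, based on the descriptions above:

    using Flux
    using Metalhead: Layers

    a = randn(Float32, 4)
    b = randn(Float32, 4)

    Layers.addact(relu, a, b)   # sum, then activate: expected to equal relu.(a .+ b)
    Layers.actadd(relu, a, b)   # activate, then sum: expected to equal relu.(a) .+ relu.(b)
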
Metalhead.Layers.cat_channels - Function
cat_channels(x, y, zs...)

Concatenate x and y (and any zs) along the channel dimension (third dimension). Equivalent to cat(x, y, zs...; dims=3). Convenient reduction operator for use with Parallel.

source
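
A minimal sketch, including its typical use as the reduction operator of a Parallel layer:

    using Flux
    using Metalhead: Layers

    x = rand(Float32, 8, 8, 3, 1)
    y = rand(Float32, 8, 8, 5, 1)
    size(Layers.cat_channels(x, y))   # (8, 8, 8, 1)

    # As the connection of a Parallel layer: concatenate the two branch outputs.
    p = Parallel(Layers.cat_channels, Conv((3, 3), 3 => 8; pad = 1), Conv((1, 1), 3 => 4))
    size(p(x))                        # expected: (8, 8, 12, 1)
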
Metalhead.Layers.flatten_chains - Function
flatten_chains(m::Chain)
flatten_chains(m)

Convenience function for traversing nested layers of a Chain object and flattening them into a single iterator.

source
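
A minimal sketch, assuming the returned iterator can be splatted back into a single flat Chain:

    using Flux
    using Metalhead: Layers

    m = Chain(Dense(2 => 3), Chain(Dense(3 => 4), Chain(Dense(4 => 5))))
    flat = Chain(Layers.flatten_chains(m)...)   # expected: a Chain of three Dense layers
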
Metalhead.Layers.swapdims - Function
swapdims(perm)

Convenience function that returns a closure which permutes the dimensions of an array. perm is a vector or tuple specifying a permutation of the input dimensions. The returned closure is equivalent to x -> permutedims(x, perm).

source
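
A minimal sketch based on the description above:

    using Metalhead: Layers

    # Closure that moves the third dimension to the front.
    f = Layers.swapdims((3, 1, 2))

    x = rand(Float32, 2, 3, 4)
    size(f(x))   # expected: (4, 2, 3)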