# Layers

Metalhead also defines a module called `Layers`, which contains some custom layers that are used to configure the models in Metalhead. These layers are not available in Flux at present. To use the functions defined in the `Layers` module, you need to import it:

```julia
using Metalhead: Layers
```

This page contains the API reference for the `Layers` module.

> **Warning**
> The `Layers` module is still a work in progress. While we will endeavour to keep the API stable, we cannot guarantee that it will not change in the future. If you find that any of the functions in this module do not work as expected, please open an issue on GitHub.
## Convolution + BatchNorm layers

**`Metalhead.Layers.conv_norm`** — Function

```julia
conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
          activation = relu; norm_layer = BatchNorm, revnorm::Bool = false,
          preact::Bool = false, stride::Integer = 1, pad::Integer = 0,
          dilation::Integer = 1, groups::Integer = 1, [bias, weight, init])
```

Create a convolution + normalisation layer pair with activation.

**Arguments**

- `kernel_size`: size of the convolution kernel (tuple)
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `norm_layer`: the normalisation layer used. Note that using `identity` as the normalisation layer will result in no normalisation being applied. (This is only compatible with `preact` and `revnorm` both set to `false`.)
- `revnorm`: set to `true` to place the normalisation layer before the convolution
- `preact`: set to `true` to place the activation function before the normalisation layer (only compatible with `revnorm = false`)
- `bias`: bias for the convolution kernel. This is set to `false` by default if `norm_layer` is not `identity` and `true` otherwise.
- `stride`: stride of the convolution kernel
- `pad`: padding of the convolution kernel
- `dilation`: dilation of the convolution kernel
- `groups`: groups for the convolution kernel
- `weight`, `init`: initialization for the convolution kernel (see `Flux.Conv`)
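A minimal usage sketch (assuming, as in recent Metalhead releases, that `conv_norm` returns a collection of layers that can be splatted into a `Chain`):

```julia
using Flux
using Metalhead: Layers

# 3×3 convolution from 16 to 32 channels with BatchNorm and relu,
# downsampling by a factor of 2. `conv_norm` is assumed to return a
# vector of layers, so we splat it into a `Chain`.
block = Chain(Layers.conv_norm((3, 3), 16, 32, relu; stride = 2, pad = 1)...)

x = rand(Float32, 32, 32, 16, 1)  # H × W × C × N
size(block(x))                    # expected: (16, 16, 32, 1)
```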
**`Metalhead.Layers.basic_conv_bn`** — Function

```julia
basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu;
              kwargs...)
```

Returns a convolution + batch normalisation pair with activation as used by the Inception family of models, with default values matching those used in the official TensorFlow implementation.

**Arguments**

- `kernel_size`: size of the convolution kernel (tuple)
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `batchnorm`: set to `true` to include batch normalization after each convolution
- `kwargs`: keyword arguments passed to `conv_norm`
## Convolution-related custom blocks

These blocks are designed to be used in convolutional neural networks. Most of them are used in the MobileNet and EfficientNet family of models, but they also feature in "fancier" versions of well-known models like ResNet (SE-ResNet).

**`Metalhead.Layers.dwsep_conv_norm`** — Function

```julia
dwsep_conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
                activation = relu; norm_layer = BatchNorm, stride::Integer = 1,
                bias::Bool = !(norm_layer !== identity), pad::Integer = 0,
                [weight, init])
```

Create a depthwise separable convolution chain as used in MobileNetv1. This is a sequence of layers:

- a `kernel_size` depthwise convolution from `inplanes => inplanes`
- a (batch) normalisation layer + `activation` (if `norm_layer !== identity`; otherwise `activation` is applied to the convolution output)
- a `kernel_size` convolution from `inplanes => outplanes`
- a (batch) normalisation layer + `activation` (if `norm_layer !== identity`; otherwise `activation` is applied to the convolution output)

See Fig. 3 in the reference.

**Arguments**

- `kernel_size`: size of the convolution kernel (tuple)
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `norm_layer`: the normalisation layer used. Note that using `identity` as the normalisation layer will result in no normalisation being applied.
- `bias`: whether to use bias in the convolution layers
- `stride`: stride of the first convolution kernel
- `pad`: padding of the first convolution kernel
- `weight`, `init`: initialization for the convolution kernel (see `Flux.Conv`)
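A small sketch, under the same assumption that the returned layers are splatted into a `Chain`:

```julia
using Flux
using Metalhead: Layers

# Depthwise separable block from 32 to 64 channels, as in MobileNetv1.
# Assumes dwsep_conv_norm returns a collection of layers to splat into a Chain.
block = Chain(Layers.dwsep_conv_norm((3, 3), 32, 64, relu; stride = 1, pad = 1)...)

x = rand(Float32, 28, 28, 32, 1)
size(block(x), 3)  # expected: 64 output channels
```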
**`Metalhead.Layers.mbconv`** — Function

```julia
mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer,
       outplanes::Integer, activation = relu; stride::Integer,
       reduction::Union{Nothing, Real} = nothing,
       se_round_fn = x -> round(Int, x), norm_layer = BatchNorm, kwargs...)
```

Create a basic inverted residual block for MobileNet and EfficientNet variants. This is a sequence of layers:

- a 1x1 convolution from `inplanes => explanes` followed by a (batch) normalisation layer + `activation` if `inplanes != explanes`
- a `kernel_size` depthwise separable convolution from `explanes => explanes`
- a (batch) normalisation layer
- a squeeze-and-excitation block (if `reduction != nothing`) from `explanes => se_round_fn(explanes / reduction)` and back to `explanes`
- a 1x1 convolution from `explanes => outplanes`
- a (batch) normalisation layer + `activation`

This function does not handle the residual connection by default. The user must add this manually to use this block as a standalone. To construct a model, check out the builders, which handle the residual connection and other details.

First introduced in the MobileNetv2 paper. (See Fig. 3 in the reference.)

**Arguments**

- `kernel_size`: kernel size of the convolutional layers
- `inplanes`: number of input feature maps
- `explanes`: the number of expanded feature maps. This is the number of feature maps after the first 1x1 convolution.
- `outplanes`: the number of output feature maps
- `activation`: the activation function for the first two convolution layers
- `stride`: the stride of the convolutional kernel, has to be either 1 or 2
- `reduction`: the reduction factor for the number of hidden feature maps in a squeeze-and-excite layer (see `squeeze_excite`)
- `se_round_fn`: the function to round the number of reduced feature maps in the squeeze-and-excite layer
- `norm_layer`: the normalization layer to use
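A hedged sketch of constructing a standalone MBConv block and adding the residual connection manually (assuming the returned block is a callable chain of layers):

```julia
using Flux
using Metalhead: Layers

# Inverted residual (MBConv) block: expand 16 → 64 channels, depthwise 3×3,
# squeeze-and-excite with reduction 4, project back down to 16 channels.
block = Layers.mbconv((3, 3), 16, 64, 16, swish; stride = 1, reduction = 4)

# The block itself has no skip connection; add one manually when
# inplanes == outplanes and stride == 1, as in MobileNetv2.
resblock = SkipConnection(block, +)

x = rand(Float32, 32, 32, 16, 1)
size(resblock(x))  # expected: (32, 32, 16, 1)
```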
**`Metalhead.Layers.fused_mbconv`** — Function

```julia
fused_mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer,
             outplanes::Integer, activation = relu;
             stride::Integer, norm_layer = BatchNorm)
```

Create a fused inverted residual block.

This is a sequence of layers:

- a `kernel_size` depthwise separable convolution from `explanes => explanes`
- a (batch) normalisation layer
- a 1x1 convolution from `explanes => outplanes` followed by a (batch) normalisation layer + `activation` if `inplanes != explanes`

This function does not handle the residual connection by default. The user must add this manually to use this block as a standalone. To construct a model, check out the builders, which handle the residual connection and other details.

Originally introduced by Google in "EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML". Later used in the EfficientNetv2 paper.

**Arguments**

- `kernel_size`: kernel size of the convolutional layers
- `inplanes`: number of input feature maps
- `explanes`: the number of expanded feature maps
- `outplanes`: the number of output feature maps
- `activation`: the activation function for the first two convolution layers
- `stride`: the stride of the convolutional kernel, has to be either 1 or 2
- `norm_layer`: the normalization layer to use
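The same pattern applies to the fused variant; a minimal sketch under the same assumptions:

```julia
using Flux
using Metalhead: Layers

# Fused MBConv block: a full 3×3 convolution for expansion (16 → 64),
# then a 1×1 projection back to 16 channels, as used in EfficientNetv2 stems.
block = Layers.fused_mbconv((3, 3), 16, 64, 16, swish; stride = 1)
resblock = SkipConnection(block, +)   # manual residual connection

x = rand(Float32, 32, 32, 16, 1)
size(resblock(x))  # expected: (32, 32, 16, 1)
```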
**`Metalhead.Layers.squeeze_excite`** — Function

```julia
squeeze_excite(inplanes::Integer; reduction::Real = 16, round_fn = _round_channels,
               norm_layer = identity, activation = relu, gate_activation = sigmoid)
```

Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE-ResNets.

**Arguments**

- `inplanes`: the number of input feature maps
- `reduction`: the reduction factor for the number of hidden feature maps in the squeeze-and-excite layer. The number of hidden feature maps is calculated as `round_fn(inplanes / reduction)`.
- `round_fn`: the function to round the number of reduced feature maps
- `activation`: the activation function for the first convolution layer
- `gate_activation`: the activation function for the gate layer
- `norm_layer`: the normalization layer to be used after the convolution layers
- `rd_planes`: the number of hidden feature maps in a squeeze-and-excite layer
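A minimal sketch, assuming the returned layer is callable and shape-preserving (it rescales the channels of its input):

```julia
using Flux
using Metalhead: Layers

# Squeeze-and-excitation block for a 64-channel feature map.
# The hidden width is round_fn(inplanes / reduction), here round_fn(64 / 16).
se = Layers.squeeze_excite(64; reduction = 16)

x = rand(Float32, 8, 8, 64, 2)
size(se(x))  # the output has the same shape as the input: (8, 8, 64, 2)
```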
**`Metalhead.Layers.effective_squeeze_excite`** — Function

```julia
effective_squeeze_excite(inplanes, gate_activation = sigmoid)
```

Effective squeeze-and-excitation layer. (Reference: "CenterMask: Real-Time Anchor-Free Instance Segmentation".)

**Arguments**

- `inplanes`: the number of input feature maps
- `gate_activation`: the activation function for the gate layer
## Normalisation, Dropout and Pooling layers

Metalhead provides various custom layers for normalisation, dropout and pooling, which have been used to further customise various models.

### Normalisation layers

**`Metalhead.Layers.ChannelLayerNorm`** — Type

```julia
ChannelLayerNorm(sz::Integer, λ = identity; eps = 1.0f-6)
```

A variant of LayerNorm where the input is normalised along the channel dimension. The input is expected to have a channel dimension of size `sz`. It also applies a learnable shift and rescaling after the normalisation.

Note that this is specifically for inputs with 4 dimensions in the format (H, W, C, N), where H and W are the height and width of the input, C is the number of channels, and N is the batch size.
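A small sketch of normalising a 4D feature map along its channel dimension:

```julia
using Metalhead: Layers

# Channel-wise LayerNorm over a (H, W, C, N) input with C = 32 channels.
ln = Layers.ChannelLayerNorm(32)

x = rand(Float32, 7, 7, 32, 4)
size(ln(x))  # shape is preserved: (7, 7, 32, 4)
```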
**`Metalhead.Layers.LayerNormV2`** — Type

```julia
LayerNormV2(size..., λ = identity; affine = true, eps = 1f-5)
```

Same as Flux's `LayerNorm`, but `eps` is added before taking the square root in the denominator, so `LayerNormV2` matches PyTorch's `LayerNorm`.
**`Metalhead.Layers.LayerScale`** — Function

```julia
LayerScale(planes::Integer, λ)
```

Creates a `Flux.Scale` layer that performs "LayerScale" (reference).

**Arguments**

- `planes`: size of channel dimension in the input
- `λ`: initialisation value for the learnable diagonal matrix
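A minimal sketch: in transformer-style blocks the input is typically (embedding, tokens, batch), and `LayerScale` rescales the embedding dimension:

```julia
using Flux
using Metalhead: Layers

# LayerScale over a 64-dimensional embedding, initialised to a small value
# so the scaled branch starts out close to zero.
ls = Layers.LayerScale(64, 1.0f-5)

x = rand(Float32, 64, 196, 2)   # embedding × tokens × batch
size(ls(x))                     # shape is preserved: (64, 196, 2)
```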
### Dropout layers

**`Metalhead.Layers.DropBlock`** — Type

```julia
DropBlock(drop_block_prob = 0.1, block_size = 7, gamma_scale = 1.0, [rng])
```

The `DropBlock` layer. During training, it zeroes out contiguous regions of size `block_size` in the input. During inference, it simply returns the input `x`. It can be used in two ways: either with all blocks having the same survival probability or with a linear scaling rule across the blocks. This is performed only at training time. At test time, the `DropBlock` layer is equivalent to `identity`.

**Arguments**

- `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns `identity`. Note that some literature uses the term "survival probability" instead, which is equivalent to `1 - drop_block_prob`.
- `block_size`: size of the block to drop
- `gamma_scale`: multiplicative factor for the `gamma` used. For the calculation of `gamma`, refer to the paper.
- `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU.
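A minimal sketch of placing a DropBlock layer inside a model; blocks are only dropped in training mode:

```julia
using Flux
using Metalhead: Layers

db = Layers.DropBlock(0.1, 7)     # drop 7×7 blocks with probability 0.1
model = Chain(Conv((3, 3), 3 => 16; pad = 1), db)

x = rand(Float32, 32, 32, 3, 1)
Flux.testmode!(model)             # in test mode DropBlock acts as identity
size(model(x))                    # expected: (32, 32, 16, 1)
```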
**`Metalhead.Layers.dropblock`** — Function

```julia
dropblock([rng], x::AbstractArray{T, 4}, drop_block_prob, block_size,
          gamma_scale, active::Bool = true)
```

The `dropblock` function. If `active` is `true`, for each input it zeroes out contiguous regions of size `block_size` in the input. Otherwise, it simply returns the input `x`.

**Arguments**

- `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU.
- `x`: input array
- `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns `identity`.
- `block_size`: size of the block to drop
- `gamma_scale`: multiplicative factor for the `gamma` used. For the calculations, refer to the paper.

If you are not a package developer, you most likely do not want this function. Use `DropBlock` instead.
**`Metalhead.Layers.StochasticDepth`** — Function

```julia
StochasticDepth(p, mode = :row; [rng])
```

Implements Stochastic Depth. This is a `Dropout` layer from Flux that drops values with probability `p`. (reference)

This layer can be used to drop certain blocks in a residual structure and allow them to propagate completely through the skip connection. It can be used in two ways: either with all blocks having the same survival probability or with a linear scaling rule across the blocks. This is performed only at training time. At test time, the `StochasticDepth` layer is equivalent to `identity`.

**Arguments**

- `p`: probability of Stochastic Depth. Note that some literature uses the term "survival probability" instead, which is equivalent to `1 - p`.
- `mode`: either `:batch` or `:row`. `:batch` randomly zeroes the entire input, `:row` zeroes randomly selected rows from the batch. The default is `:row`.
- `rng`: can be used to pass in a custom RNG instead of the default. See `Flux.Dropout` for more information on the behaviour of this argument. Custom RNGs are only supported on the CPU.
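A sketch of the typical placement: stochastic depth is applied to the residual branch, so that when the branch is dropped only the skip connection remains:

```julia
using Flux
using Metalhead: Layers

# Residual block whose branch is dropped per-sample with probability 0.2.
branch = Chain(Conv((3, 3), 16 => 16; pad = 1), Layers.StochasticDepth(0.2))
block = SkipConnection(branch, +)

x = rand(Float32, 32, 32, 16, 4)
size(block(x))  # expected: (32, 32, 16, 4)
```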
### Pooling layers

**`Metalhead.Layers.AdaptiveMeanMaxPool`** — Function

```julia
AdaptiveMeanMaxPool([connection = +], output_size::Tuple = (1, 1))
```

A type of adaptive pooling layer which uses both mean and max pooling and combines them to produce a single output. Note that this is equivalent to `Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size))`. When `connection` is not specified, it defaults to `+`.

**Arguments**

- `connection`: the connection type to use
- `output_size`: the size of the output after pooling
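A brief sketch of pooling a feature map down to 1×1 using the sum of mean and max pooling:

```julia
using Flux
using Metalhead: Layers

pool = Layers.AdaptiveMeanMaxPool((1, 1))   # mean-pool + max-pool, combined with +

x = rand(Float32, 7, 7, 512, 2)
size(pool(x))  # expected: (1, 1, 512, 2)
```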
## Classifier creation

Metalhead provides a flexible function to create a classifier for neural-network models, and it is used extensively by the library to create the classifier "head" for networks.

**`Metalhead.Layers.create_classifier`** — Function

```julia
create_classifier(inplanes::Integer, nclasses::Integer, activation = identity;
                  use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)),
                  dropout_prob = nothing)
```

Creates a classifier head to be used for models.

**Arguments**

- `inplanes`: number of input feature maps
- `nclasses`: number of output classes
- `activation`: activation function to use
- `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer
- `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with any arguments that it needs, e.g. `AdaptiveMeanPool((1, 1))`.
- `dropout_prob`: dropout probability used in the classifier head. Set to `nothing` to disable dropout.

```julia
create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer,
                  activations::NTuple{2} = (relu, identity);
                  use_conv::NTuple{2, Bool} = (false, false),
                  pool_layer = AdaptiveMeanPool((1, 1)), dropout_prob = nothing)
```

Creates a classifier head to be used for models with an extra hidden layer.

**Arguments**

- `inplanes`: number of input feature maps
- `hidden_planes`: number of hidden feature maps
- `nclasses`: number of output classes
- `activations`: activation functions to use for the hidden and output layers. This is a tuple of two elements: the first is the activation function for the hidden layer and the second for the output layer.
- `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. This is a tuple of two booleans: the first for the hidden layer and the second for the output layer.
- `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with any arguments that it needs, e.g. `AdaptiveMeanPool((1, 1))`.
- `dropout_prob`: dropout probability used in the classifier head. Set to `nothing` to disable dropout.
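A hedged sketch of attaching a classifier head to a convolutional backbone, assuming the returned head is a `Chain` that pools, flattens and projects to `nclasses`:

```julia
using Flux
using Metalhead: Layers

# Backbone producing 512 feature maps, followed by a 10-class classifier head
# with dropout before the final Dense layer.
backbone = Chain(Conv((3, 3), 3 => 512; pad = 1))
head = Layers.create_classifier(512, 10; dropout_prob = 0.2)
model = Chain(backbone, head)

x = rand(Float32, 32, 32, 3, 4)
size(model(x))  # expected: (10, 4)
```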
## Vision transformer-related layers

The `Layers` module contains specific layers that are used to build vision transformer (ViT)-inspired models:

**`Metalhead.Layers.MultiHeadSelfAttention`** — Type

```julia
MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
                       attn_dropout_prob = 0., proj_dropout_prob = 0.)
```

Multi-head self-attention layer.

**Arguments**

- `planes`: number of input channels
- `nheads`: number of heads
- `qkv_bias`: whether to use bias in the layer to get the query, key and value
- `attn_dropout_prob`: dropout probability after the self-attention layer
- `proj_dropout_prob`: dropout probability after the projection layer
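A minimal sketch, assuming the layer operates on inputs of shape (embedding, tokens, batch), as is conventional for ViT-style models in Flux, and returns an output of the same shape:

```julia
using Metalhead: Layers

# 8-head self-attention over a 64-dimensional embedding.
mhsa = Layers.MultiHeadSelfAttention(64, 8)

x = rand(Float32, 64, 197, 2)   # embedding × tokens × batch
size(mhsa(x))                   # shape is preserved: (64, 197, 2)
```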
**`Metalhead.Layers.ClassTokens`** — Type

```julia
ClassTokens(planes::Integer; init = Flux.zeros32)
```

Appends class tokens to an input with embedding dimension `planes` for use in many vision transformer models.
**`Metalhead.Layers.ViPosEmbedding`** — Type

```julia
ViPosEmbedding(embedsize::Integer, npatches::Integer;
               init = (dims::Dims{2}) -> rand(Float32, dims))
```

Positional embedding layer used by many vision transformer-like models.
**`Metalhead.Layers.PatchEmbedding`** — Function

```julia
PatchEmbedding(imsize::Dims{2} = (224, 224); inchannels::Integer = 3,
               patch_size::Dims{2} = (16, 16), embedplanes = 768,
               norm_layer = planes -> identity, flatten = true)
```

Patch embedding layer used by many vision transformer-like models to split the input image into patches.

**Arguments**

- `imsize`: the size of the input image
- `inchannels`: number of input channels
- `patch_size`: the size of the patches
- `embedplanes`: the number of channels in the embedding
- `norm_layer`: the normalization layer. By default it is the identity function, but otherwise it takes a single-argument constructor for a normalization layer like `LayerNorm` or `BatchNorm`.
- `flatten`: set to `true` to flatten the input spatial dimensions after the embedding
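A hedged sketch combining the three layers above into a ViT-style front end. The shapes assume a 224×224 RGB input split into 16×16 patches (14 × 14 = 196 patches), with a class token appended to give 197 tokens; the positional embedding is sized to cover all 197 tokens:

```julia
using Flux
using Metalhead: Layers

embedplanes = 192
frontend = Chain(Layers.PatchEmbedding((224, 224); inchannels = 3,
                                       patch_size = (16, 16), embedplanes),
                 Layers.ClassTokens(embedplanes),
                 Layers.ViPosEmbedding(embedplanes, 197))

x = rand(Float32, 224, 224, 3, 2)
size(frontend(x))  # expected: (192, 197, 2), i.e. embedding × tokens × batch
```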
## MLPMixer-related blocks

Apart from this, the `Layers` module also contains certain blocks used in MLPMixer-style models:

**`Metalhead.Layers.gated_mlp_block`** — Function

```julia
gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
                outplanes::Integer = inplanes; dropout_prob = 0.0, activation = gelu)
```

Feedforward block based on the implementation in the paper "Pay Attention to MLPs". (reference)

**Arguments**

- `gate_layer`: layer to use for the gating
- `inplanes`: number of dimensions in the input
- `hidden_planes`: number of dimensions in the intermediate layer
- `outplanes`: number of dimensions in the output. By default it is the same as `inplanes`.
- `dropout_prob`: dropout probability
- `activation`: activation function to use
**`Metalhead.Layers.mlp_block`** — Function

```julia
mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
          dropout_prob = 0., activation = gelu)
```

Feedforward block used in many MLPMixer-like and vision-transformer models.

**Arguments**

- `inplanes`: number of dimensions in the input
- `hidden_planes`: number of dimensions in the intermediate layer
- `outplanes`: number of dimensions in the output. By default it is the same as `inplanes`.
- `dropout_prob`: dropout probability
- `activation`: activation function to use
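A short sketch, assuming `mlp_block` returns a callable chain of two dense layers acting on the first (embedding) dimension:

```julia
using Metalhead: Layers

# Transformer-style feedforward block: 64 → 256 → 64 with GELU in between.
ffn = Layers.mlp_block(64, 256)

x = rand(Float32, 64, 197, 2)
size(ffn(x))  # shape is preserved: (64, 197, 2)
```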
## Utilities for layers

These are some miscellaneous utilities present in the `Layers` module, used alongside other custom and built-in layers to make certain common operations in neural networks easier.

**`Metalhead.Layers.inputscale`** — Function

```julia
inputscale(λ; activation = identity)
```

Scales the input by a scalar `λ` and applies an activation function to it. Equivalent to `activation.(λ .* x)`.
**`Metalhead.Layers.actadd`** — Function

```julia
actadd(activation = relu, xs...)
```

Convenience function for adding input arrays after applying an activation function to them. Useful as the `connection` argument for the block function in `resnet`.

**`Metalhead.Layers.addact`** — Function

```julia
addact(activation = relu, xs...)
```

Convenience function for applying an activation function to the output after summing up the input arrays. Useful as the `connection` argument for the block function in `resnet`.
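A quick illustration of the difference between the two, based on the definitions above:

```julia
using Flux
using Metalhead: Layers

x, y = randn(Float32, 4), randn(Float32, 4)

Layers.addact(relu, x, y)   # relu.(x .+ y): activation applied after the sum
Layers.actadd(relu, x, y)   # relu.(x) .+ relu.(y): activation applied before the sum
```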
**`Metalhead.Layers.cat_channels`** — Function

```julia
cat_channels(x, y, zs...)
```

Concatenates `x` and `y` (and any `z`s) along the channel dimension (third dimension). Equivalent to `cat(x, y, zs...; dims = 3)`. Convenient reduction operator for use with `Parallel`.
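For instance, a sketch of an Inception-style split-and-concatenate block using `cat_channels` as the reduction for `Parallel`:

```julia
using Flux
using Metalhead: Layers

# Two parallel branches whose outputs are concatenated along the channel dim.
block = Parallel(Layers.cat_channels,
                 Conv((1, 1), 16 => 8),
                 Conv((3, 3), 16 => 8; pad = 1))

x = rand(Float32, 28, 28, 16, 1)
size(block(x))  # expected: (28, 28, 16, 1), i.e. 8 + 8 channels
```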
**`Metalhead.Layers.flatten_chains`** — Function

```julia
flatten_chains(m::Chain)
flatten_chains(m)
```

Convenience function for traversing the nested layers of a `Chain` object and flattening them into a single iterator.
**`Metalhead.Layers.linear_scheduler`** — Function

```julia
linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth)
linear_scheduler(drop_prob::Nothing; depth::Integer)
```

Returns the dropout probabilities for a given depth using the linear scaling rule. Note that this returns evenly spaced values between `start_value` and `drop_prob`, not including `drop_prob`. If `drop_prob` is `nothing`, it returns a `Vector` of length `depth` with all values equal to `nothing`.
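A sketch of using the scheduler to assign increasing drop probabilities to the blocks of a deep network, per the description above (values run from `start_value` up to, but not including, `drop_prob`):

```julia
using Metalhead: Layers

# Per-block stochastic-depth probabilities for a 4-block network.
probs = Layers.linear_scheduler(0.2; depth = 4)
# expected: 4 evenly spaced values from 0.0 up to (but not including) 0.2,
# i.e. approximately [0.0, 0.05, 0.1, 0.15]

blocks = [Layers.StochasticDepth(p) for p in probs]
```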
**`Metalhead.Layers.swapdims`** — Function

```julia
swapdims(perm)
```

Convenience function that returns a closure which permutes the dimensions of an array. `perm` is a vector or tuple specifying a permutation of the input dimensions. Equivalent to `permutedims(x, perm)`.