ResNet-like models
This is the API reference for the ResNet-inspired model structures present in Metalhead.jl.
The higher-level model constructors
Metalhead.ResNet
Type: ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
Creates a ResNet model with the specified depth. (reference)
Arguments
- depth: one of [18, 34, 50, 101, 152]. The depth of the ResNet model.
- pretrain: set to true to load the model with pre-trained weights for ImageNet.
- inchannels: the number of input channels.
- nclasses: the number of output classes.
Advanced users who want more configuration options will be better served by using resnet.
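As a quick sanity check on the depth argument, the sketch below tabulates the block layouts from the original ResNet paper and recomputes each nominal depth: every basic block contributes two weighted layers, every bottleneck block three, plus the stem convolution and the final dense layer. It is plain Julia with no Metalhead dependency. The table name RESNET_LAYOUTS is hypothetical, and the assumption that it matches Metalhead's internal configuration is not confirmed by this page.

```julia
# Hypothetical table: depth => (block kind, block_repeats), taken from the
# original ResNet paper. Assumed (not confirmed) to match Metalhead's internals.
const RESNET_LAYOUTS = Dict(
    18  => (:basicblock, [2, 2, 2, 2]),
    34  => (:basicblock, [3, 4, 6, 3]),
    50  => (:bottleneck, [3, 4, 6, 3]),
    101 => (:bottleneck, [3, 4, 23, 3]),
    152 => (:bottleneck, [3, 8, 36, 3]),
)

# Weighted layers per block: two 3x3 convs (basic) or 1x1-3x3-1x1 (bottleneck)
layers_per_block(kind::Symbol) = kind === :basicblock ? 2 : 3

# Nominal depth = layers in all blocks + stem convolution + final dense layer
function nominal_depth(kind::Symbol, block_repeats::AbstractVector{<:Integer})
    return layers_per_block(kind) * sum(block_repeats) + 2
end

for depth in sort(collect(keys(RESNET_LAYOUTS)))
    kind, repeats = RESNET_LAYOUTS[depth]
    println(depth, " => ", kind, " ", repeats, " (", nominal_depth(kind, repeats), " layers)")
end
```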
Metalhead.WideResNet
Type: WideResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
Creates a Wide ResNet model with the specified depth. The model is the same as ResNet except for the number of bottleneck channels, which is twice as large in every block. The number of channels in the outer 1x1 convolutions is the same. (reference)
Arguments
- depth: one of [18, 34, 50, 101, 152]. The depth of the Wide ResNet model.
- pretrain: set to true to load the model with pre-trained weights for ImageNet.
- inchannels: the number of input channels.
- nclasses: the number of output classes.
Advanced users who want more configuration options will be better served by using resnet.
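To make the width difference concrete, this plain-Julia sketch compares the channel counts of a bottleneck ResNet and its wide counterpart stage by stage, taking the inner width as planes or 2 * planes (as described above) and the outer width as 4 * planes. The stage_planes values [64, 128, 256, 512] are the standard ResNet widths and are an assumption here; the helper names are hypothetical.

```julia
stage_planes = [64, 128, 256, 512]              # assumed standard ResNet widths

inner_width(planes; wide::Bool) = wide ? 2 * planes : planes
outer_width(planes) = 4 * planes                 # identical for both models
conv3x3_weights(w) = 3 * 3 * w * w               # w -> w channels, no bias

for planes in stage_planes
    w, ww = inner_width(planes; wide = false), inner_width(planes; wide = true)
    println("planes=$planes: inner $w vs $ww, outer $(outer_width(planes)), ",
            "3x3 weights $(conv3x3_weights(w)) vs $(conv3x3_weights(ww))")
end
```

Doubling the inner width quadruples the weights of each 3x3 convolution while leaving the block's input and output channel counts unchanged.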
Metalhead.ResNeXt
Type: ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
Creates a ResNeXt model with the specified depth, cardinality, and base width. (reference)
Arguments
- depth: one of [50, 101, 152]. The depth of the ResNeXt model.
- pretrain: set to true to load the model with pre-trained weights for ImageNet. Supported configurations are:
  - depth 50, cardinality of 32 and base width of 4.
  - depth 101, cardinality of 32 and base width of 8.
  - depth 101, cardinality of 64 and base width of 4.
- cardinality: the number of groups to be used in the 3x3 convolution in each block.
- base_width: the number of feature maps in each group.
- inchannels: the number of input channels.
- nclasses: the number of output classes.
Advanced users who want more configuration options will be better served by using resnet.
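The cardinality argument turns the 3x3 convolution into a grouped convolution. The sketch below illustrates the idea with a grouped linear map over channels (a 1x1 stand-in for the 3x3 convolution, with spatial dimensions dropped): each of the cardinality groups only sees its own slice of channels, which is equivalent to a block-diagonal weight matrix with cardinality times fewer weights. The values 32 and 128 correspond to the first stage of ResNeXt-50 (32x4d) in the ResNeXt paper; grouped_apply is a hypothetical helper, not a Metalhead function.

```julia
# Grouped linear map over channels: the 1x1 analogue of a grouped convolution.
# Ws[g] maps group g's slice of input channels to its slice of output channels.
function grouped_apply(Ws::AbstractVector{<:AbstractMatrix}, x::AbstractVector)
    G = length(Ws)
    cin_g = size(first(Ws), 2)
    length(x) == G * cin_g || throw(DimensionMismatch("length(x) must equal G * cin_g"))
    # Group g only sees its own contiguous slice of the input channels
    return reduce(vcat, [Ws[g] * x[(g - 1) * cin_g + 1:g * cin_g] for g in 1:G])
end

cardinality, width = 32, 128             # first stage of ResNeXt-50 (32x4d)
cin_g = cout_g = width ÷ cardinality     # 4 channels per group
Ws = [randn(cout_g, cin_g) for _ in 1:cardinality]
x = randn(width)
y = grouped_apply(Ws, x)

# Same result as one dense map with a block-diagonal weight matrix...
@assert y ≈ cat(Ws...; dims = (1, 2)) * x
# ...but the grouped version stores `cardinality` times fewer weights
println(sum(length, Ws), " grouped weights vs ", width^2, " dense")   # 512 vs 16384
```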
Metalhead.SEResNet
Type: SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
Creates a SEResNet model with the specified depth. (reference)
Arguments
- depth: one of [18, 34, 50, 101, 152]. The depth of the SEResNet model.
- pretrain: set to true to load the model with pre-trained weights for ImageNet.
- inchannels: the number of input channels.
- nclasses: the number of output classes.
SEResNet does not currently support pretrained weights.
Advanced users who want more configuration options will be better served by using resnet.
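SEResNet adds a squeeze-and-excitation (SE) step to each residual block. The following plain-Julia sketch shows the computation on an array in Flux's WHCN layout: average each channel over space, pass the result through a small two-layer bottleneck with reduction ratio r, and rescale every channel by a sigmoid gate. The weights W1 and W2 and the function name are hypothetical, biases are omitted, and this is not Metalhead's Metalhead.Layers.squeeze_excite implementation.

```julia
# Squeeze-and-excitation on a WHCN array (width, height, channels, batch).
# W1 (C÷r × C) and W2 (C × C÷r) are hypothetical weights; biases are omitted.
function squeeze_excite_sketch(x::AbstractArray{<:Real,4}, W1::AbstractMatrix,
                               W2::AbstractMatrix)
    w, h, c, n = size(x)
    # Squeeze: average over the two spatial dimensions -> C × N matrix
    s = reshape(sum(x; dims = (1, 2)) ./ (w * h), c, n)
    # Excite: bottleneck MLP with ReLU, then a sigmoid gate in (0, 1) per channel
    z = max.(W1 * s, 0)
    g = 1 ./ (1 .+ exp.(-(W2 * z)))
    # Scale: every spatial position of channel k is multiplied by its gate
    return x .* reshape(g, 1, 1, c, n)
end

C, r = 8, 4
x = rand(Float32, 5, 5, C, 2)
W1, W2 = randn(Float32, C ÷ r, C), randn(Float32, C, C ÷ r)
y = squeeze_excite_sketch(x, W1, W2)
@assert size(y) == size(x)
```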
Metalhead.SEResNeXt
Type: SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
Creates a SEResNeXt model with the specified depth, cardinality, and base width. (reference)
Arguments
- depth: one of [50, 101, 152]. The depth of the SEResNeXt model.
- pretrain: set to true to load the model with pre-trained weights for ImageNet.
- cardinality: the number of groups to be used in the 3x3 convolution in each block.
- base_width: the number of feature maps in each group.
- inchannels: the number of input channels.
- nclasses: the number of output classes.
SEResNeXt does not currently support pretrained weights.
Advanced users who want more configuration options will be better served by using resnet.
Metalhead.Res2Net
Type: Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
base_width::Integer = 26, inchannels::Integer = 3,
nclasses::Integer = 1000)
Creates a Res2Net model with the specified depth, scale, and base width. (reference)
Arguments
- depth: one of [50, 101, 152]. The depth of the Res2Net model.
- pretrain: set to true to load the model with pre-trained weights for ImageNet.
- scale: the number of feature groups in the block. See the paper for more details.
- base_width: the number of feature maps in each group.
- inchannels: the number of input channels.
- nclasses: the number of output classes.
Res2Net does not currently support pretrained weights.
Advanced users who want more configuration options will be better served by using resnet.
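With the defaults scale = 4 and base_width = 26, Res2Net-50 corresponds to the 26w×4s configuration of the Res2Net paper. The sketch below computes the channel split per stage under an assumed timm-style rule, width = fld(planes * base_width, 64) * cardinality channels for each of the scale groups. Both this rule and the per-stage planes values (and the 4 * planes block output) are assumptions, not taken from this page.

```julia
# Assumed (timm-style) per-group width of a Res2Net bottleneck;
# not confirmed to be Metalhead's exact code.
group_width(planes; base_width = 26, cardinality = 1) =
    fld(planes * base_width, 64) * cardinality

scale = 4
for planes in (64, 128, 256, 512)          # assumed planes, one entry per stage
    w = group_width(planes)
    # the first 1x1 conv produces `scale` groups of `w` channels each
    println("planes=$planes: $scale groups x $w = $(scale * w) channels, ",
            "block output $(4 * planes)")
end
```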
Metalhead.Res2NeXt
Type: Res2NeXt(depth::Integer; pretrain::Bool = false, scale::Integer = 4,
base_width::Integer = 4, cardinality::Integer = 8,
inchannels::Integer = 3, nclasses::Integer = 1000)
Creates a Res2NeXt model with the specified depth, scale, base width and cardinality. (reference)
Arguments
- depth: one of [50, 101, 152]. The depth of the Res2NeXt model.
- pretrain: set to true to load the model with pre-trained weights for ImageNet.
- scale: the number of feature groups in the block. See the paper for more details.
- base_width: the number of feature maps in each group.
- cardinality: the number of groups in the 3x3 convolutions.
- inchannels: the number of input channels.
- nclasses: the number of output classes.
Res2NeXt does not currently support pretrained weights.
Advanced users who want more configuration options will be better served by using resnet.
The mid-level function
Metalhead.resnet
Function: resnet(block_type, block_repeats::AbstractVector{<:Integer},
downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
cardinality::Integer = 1, base_width::Integer = 64,
inplanes::Integer = 64, reduction_factor::Integer = 1,
connection = addact, activation = relu,
norm_layer = BatchNorm, revnorm::Bool = false,
attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)),
use_conv::Bool = false, dropblock_prob = nothing,
stochastic_depth_prob = nothing, dropout_prob = nothing,
imsize::Dims{2} = (256, 256), inchannels::Integer = 3,
nclasses::Integer = 1000, kwargs...)
Creates a generic ResNet-like model that is used to create the higher-level model constructors like ResNet, Wide ResNet, ResNeXt and Res2Net. For an even more generic model API, see Metalhead.build_resnet.
Arguments
- block_type: The type of block to be used in the model. This can be one of Metalhead.basicblock, Metalhead.bottleneck and Metalhead.bottle2neck. basicblock is used in the original ResNet paper for ResNet-18 and ResNet-34, and bottleneck is used in the original ResNet-50, ResNet-101 and ResNet-152 models, as well as for the Wide ResNet and ResNeXt models. bottle2neck is introduced in the Res2Net paper.
- block_repeats: A Vector of integers specifying the number of times each block is repeated in each stage of the ResNet model. For example, [3, 4, 6, 3] is the configuration used in ResNet-50, which has 3 blocks in the first stage, 4 blocks in the second stage, 6 blocks in the third stage and 3 blocks in the fourth stage.
- downsample_opt: An NTuple of two callbacks that determine the downsampling operations used in the model. The first callback builds the convolutional downsampling operation and the second builds the identity downsampling operation.
- cardinality: The number of groups to be used in the 3x3 convolutional layer in the bottleneck block. This is usually changed from the default value of 1 used in the ResNet models to 32 or 64 in the ResNeXt models.
- base_width: The base width of the convolutional layers in the blocks of the model.
- inplanes: The number of input channels in the first convolutional layer.
- reduction_factor: The reduction factor used in the model.
- connection: A function that determines the residual connection in the model. For ResNets, either Metalhead.Layers.addact or Metalhead.Layers.actadd is recommended. These decide whether the residual connection is added before or after the activation function.
- norm_layer: The normalisation layer to be used in the model.
- revnorm: Set to true to place the normalisation layers before the convolutions.
- attn_fn: A callback that determines the attention function to be used in the model. See Metalhead.Layers.squeeze_excite for an example.
- pool_layer: A fully-instantiated pooling layer passed in to be used by the classifier head. For example, AdaptiveMeanPool((1, 1)) is used in the ResNet family by default, but something like MeanPool((3, 3)) should also work, provided the dimensions after applying the pooling layer are compatible with the rest of the classifier head.
- use_conv: Set to true to use convolutions instead of identity operations in the model.
- dropblock_prob: DropBlock probability to be used in the model. Set to nothing to disable DropBlock. See Metalhead.DropBlock for more details.
- stochastic_depth_prob: StochasticDepth probability to be used in the model. Set to nothing to disable StochasticDepth. See Metalhead.StochasticDepth for more details.
- dropout_prob: Dropout probability to be used in the classifier head. Set to nothing to disable Dropout.
- imsize: The size of the input (height, width).
- inchannels: The number of input channels.
- nclasses: The number of output classes.
- kwargs: Additional keyword arguments to be passed to the block builder. If you are not sure what this argument does, you can ignore it; to learn more, see the section of the documentation on builders in Metalhead, in particular for the ResNet block functions.
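The connection argument only changes where the activation sits relative to the residual addition. The plain-Julia sketch below contrasts the two behaviours described above on a pair of small vectors. addact_sketch and actadd_sketch are hypothetical stand-ins whose names and argument order are not Metalhead's, and actadd_sketch assumes the activation is applied to every input before summing.

```julia
relu_sketch(x) = max(x, zero(x))

# add, then activate: activation(x + y)
addact_sketch(activation, xs...) = activation.(sum(xs))
# activate each input, then add: activation(x) + activation(y)
actadd_sketch(activation, xs...) = sum(activation.(x) for x in xs)

shortcut = [-1.0, 2.0]
residual = [0.5, -3.0]
addact_sketch(relu_sketch, shortcut, residual)   # relu([-0.5, -1.0]) == [0.0, 0.0]
actadd_sketch(relu_sketch, shortcut, residual)   # [0.0, 2.0] + [0.5, 0.0] == [0.5, 2.0]
```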
Lower-level functions and builders
Block functions
Metalhead.basicblock
Function: basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
reduction_factor::Integer = 1, activation = relu,
norm_layer = BatchNorm, revnorm::Bool = false,
drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
Creates a basic residual block (see reference). This function creates the layers. For more configuration options and to see the function used to build the block for the model, see Metalhead.basicblock_builder.
Arguments
- inplanes: number of input feature maps.
- planes: number of feature maps for the block.
- stride: the stride of the block.
- reduction_factor: the factor by which the input feature maps are reduced before the first convolution.
- activation: the activation function to use.
- norm_layer: the normalization layer to use.
- revnorm: set to true to place the normalisation layer before the convolution.
- drop_block: the drop block layer.
- drop_path: the drop path layer.
- attn_fn: the attention function to use. See squeeze_excite for an example.
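The residual computation of a basic block can be sketched with dense matrices standing in for the two 3x3 convolutions (normalisation omitted): the block output is the activation of F(x) + shortcut(x), where F is conv, activation, conv. The weights W1 and W2, the shortcut keyword and the function name are hypothetical; a non-identity shortcut (such as one of the downsampling functions documented in this reference) is needed whenever inplanes differs from planes.

```julia
# Basic residual block on a feature vector; dense matrices stand in for convs.
function basic_residual(x::AbstractVector, W1::AbstractMatrix, W2::AbstractMatrix;
                        shortcut = identity)
    act(v) = max.(v, 0)               # ReLU, applied elementwise
    fx = W2 * act(W1 * x)             # residual branch F(x): "conv", ReLU, "conv"
    return act(fx .+ shortcut(x))     # add the shortcut, then activate
end

inplanes = planes = 4
W1, W2 = randn(planes, inplanes), randn(planes, planes)
y = basic_residual(randn(inplanes), W1, W2)
```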
Metalhead.bottleneck
Function: bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
cardinality::Integer = 1, base_width::Integer = 64,
reduction_factor::Integer = 1, activation = relu,
norm_layer = BatchNorm, revnorm::Bool = false,
drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
Creates a bottleneck residual block (see reference). This function creates the layers. For more configuration options and to see the function used to build the block for the model, see Metalhead.bottleneck_builder.
Arguments
- inplanes: number of input feature maps.
- planes: number of feature maps for the block.
- stride: the stride of the block.
- cardinality: the number of groups in the convolution.
- base_width: the number of output feature maps for each convolutional group.
- reduction_factor: the factor by which the input feature maps are reduced before the first convolution.
- activation: the activation function to use.
- norm_layer: the normalization layer to use.
- revnorm: set to true to place the normalisation layer before the convolution.
- drop_block: the drop block layer.
- drop_path: the drop path layer.
- attn_fn: the attention function to use. See squeeze_excite for an example.
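The bottleneck design from the ResNet paper keeps the cost of a block close to that of a basic block while the block's input and output are four times wider. Counting convolution weights (biases and normalisation ignored) makes this visible. conv_weights is a hypothetical helper, and the 64 and 256 channel counts are the first-stage values of ResNet-34 and ResNet-50.

```julia
# Weights of a k x k convolution from cin to cout channels (no bias, no norm)
conv_weights(k, cin, cout) = k * k * cin * cout

# Basic block at 64 channels: two 3x3 convolutions
basic = 2 * conv_weights(3, 64, 64)
# Bottleneck at 256 channels: 1x1 down to 64, 3x3 at 64, 1x1 back up to 256
bottle = conv_weights(1, 256, 64) + conv_weights(3, 64, 64) + conv_weights(1, 64, 256)

println("basic: $basic, bottleneck: $bottle")   # basic: 73728, bottleneck: 69632
```

The 3x3 convolution dominates in both cases; the two cheap 1x1 convolutions are what let the bottleneck operate on 256 channels at roughly the same cost.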
Metalhead.bottle2neck
Function: bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1,
cardinality::Integer = 1, base_width::Integer = 26,
scale::Integer = 4, activation = relu, norm_layer = BatchNorm,
revnorm::Bool = false, attn_fn = planes -> identity)
Creates a bottleneck block as described in the Res2Net paper. (reference) This function creates the layers. For more configuration options and to see the function used to build the block for the model, see Metalhead.bottle2neck_builder.
Arguments
- inplanes: number of input feature maps.
- planes: number of feature maps for the block.
- stride: the stride of the block.
- cardinality: the number of groups in the 3x3 convolutions.
- base_width: the number of output feature maps for each convolutional group.
- scale: the number of feature groups in the block. See the paper for more details.
- activation: the activation function to use.
- norm_layer: the normalization layer to use.
- revnorm: set to true to place the normalisation layer before the convolution.
- attn_fn: the attention function to use. See squeeze_excite for an example.
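Inside a bottle2neck block, the output of the first 1x1 convolution is split into scale groups that are processed hierarchically, as described in the Res2Net paper: y1 = x1, y2 = K2(x2), and yi = Ki(xi + y(i-1)) for the later groups, after which the groups are concatenated. The sketch below shows this on plain vectors for a stride-1 block and ignores the strided (downsampling) variant. res2net_split and the transforms in K, which stand in for the 3x3 convolutions, are hypothetical.

```julia
# Hierarchical residual split of one Res2Net block (stride-1 case), on vectors.
# K holds `scale - 1` hypothetical transforms standing in for the 3x3 convs.
function res2net_split(x::AbstractVector, scale::Integer, K::AbstractVector)
    scale >= 2 || throw(ArgumentError("scale must be at least 2"))
    length(K) == scale - 1 || throw(ArgumentError("need scale - 1 transforms"))
    width, r = divrem(length(x), scale)
    r == 0 || throw(ArgumentError("length(x) must be divisible by scale"))
    xs = [x[(i - 1) * width + 1:i * width] for i in 1:scale]   # split into groups
    ys = similar(xs)
    ys[1] = xs[1]                              # y1 = x1, passed through untouched
    ys[2] = K[1](xs[2])                        # y2 = K2(x2)
    for i in 3:scale
        ys[i] = K[i - 1](xs[i] .+ ys[i - 1])   # yi = Ki(xi + y(i-1))
    end
    return reduce(vcat, ys)                    # concatenate before the final 1x1 conv
end

y = res2net_split(collect(1.0:8.0), 4, fill(v -> 2 .* v, 3))
```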
Downsampling functions
Metalhead.downsample_identity
Function: downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
Creates an identity downsample layer. This returns identity if inplanes == outplanes. If outplanes > inplanes, it maps the input to outplanes channels using a 1x1 max pooling layer and zero padding.
This does not currently support the scenario where inplanes > outplanes.
Arguments
- inplanes: number of input feature maps.
- outplanes: number of output feature maps.
Note that kwargs are ignored and only included for compatibility with other downsample layers.
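A minimal plain-Julia sketch of the behaviour described above, on an array in Flux's WHCN layout. It assumes the 1x1 max pooling uses stride 2 (so it reduces to keeping every other pixel) and that the zero channels are appended after the existing ones; identity_shortcut_sketch is a hypothetical name, not a Metalhead function.

```julia
# Parameter-free shortcut in the spirit of downsample_identity (WHCN layout).
function identity_shortcut_sketch(x::AbstractArray{T,4}, inplanes::Integer,
                                  outplanes::Integer) where {T}
    size(x, 3) == inplanes || throw(DimensionMismatch("expected $inplanes channels"))
    outplanes == inplanes && return x                  # plain identity
    outplanes > inplanes ||
        throw(ArgumentError("inplanes > outplanes is not supported"))
    y = x[1:2:end, 1:2:end, :, :]                      # 1x1 max pool, assumed stride 2
    pad = zeros(T, size(y, 1), size(y, 2), outplanes - inplanes, size(y, 4))
    return cat(y, pad; dims = 3)                       # zero-pad the channel dimension
end

x = rand(Float32, 8, 8, 64, 1)
size(identity_shortcut_sketch(x, 64, 128))   # (4, 4, 128, 1)
```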
Metalhead.downsample_conv
Function: downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1,
norm_layer = BatchNorm, revnorm::Bool = false)
Creates a 1x1 convolutional downsample layer as used in ResNet.
Arguments
- inplanes: number of input feature maps.
- outplanes: number of output feature maps.
- stride: the stride of the convolution.
- norm_layer: the normalization layer to use.
- revnorm: set to true to place the normalisation layer before the convolution.
Metalhead.downsample_pool
Function: downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1,
norm_layer = BatchNorm, revnorm::Bool = false)
Creates a pooling-based downsample layer as described in the Bag of Tricks paper. This adds an average pooling layer of size (2, 2) with the given stride, followed by a 1x1 convolution.
Arguments
- inplanes: number of input feature maps.
- outplanes: number of output feature maps.
- stride: the stride of the convolution.
- norm_layer: the normalization layer to use.
- revnorm: set to true to place the normalisation layer before the convolution.
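The motivation given in the Bag of Tricks paper is that a 1x1 convolution with stride 2 ignores three quarters of its input, whereas a 2x2 average pool with stride 2 lets every input value contribute before the channel projection. This plain-Julia sketch compares the two on a single 4x4 channel; meanpool2x2 is a hypothetical helper that assumes even spatial sizes.

```julia
# 2x2 average pooling with stride 2 on one channel (assumes even height and width)
function meanpool2x2(x::AbstractMatrix)
    h, w = size(x)
    return [sum(@view x[i:i + 1, j:j + 1]) / 4 for i in 1:2:h - 1, j in 1:2:w - 1]
end

x = reshape(collect(1.0:16.0), 4, 4)
strided = x[1:2:end, 1:2:end]   # what a 1x1 stride-2 conv reads: 4 of 16 values
pooled = meanpool2x2(x)         # every input value contributes to some output
```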
Block builders
Metalhead.basicblock_builder
Function: basicblock_builder(block_repeats::AbstractVector{<:Integer};
inplanes::Integer = 64, reduction_factor::Integer = 1,
expansion::Integer = 1, norm_layer = BatchNorm,
revnorm::Bool = false, activation = relu,
attn_fn = planes -> identity,
dropblock_prob = nothing, stochastic_depth_prob = nothing,
stride_fn = resnet_stride, planes_fn = resnet_planes,
downsample_tuple = (downsample_conv, downsample_identity))
Builder for creating a basic block for a ResNet model. (reference)
Arguments
- block_repeats: number of repeats of a block in each stage.
- inplanes: number of input channels.
- reduction_factor: reduction factor for the number of channels in each stage.
- expansion: expansion factor for the number of channels for the block.
- norm_layer: normalization layer to use.
- revnorm: set to true to place the normalization layer before the convolution.
- activation: activation function to use.
- attn_fn: attention function to use.
- dropblock_prob: dropblock probability. Set to nothing to disable DropBlock.
- stochastic_depth_prob: stochastic depth probability. Set to nothing to disable StochasticDepth.
- stride_fn: callback for computing the stride of the block.
- planes_fn: callback for computing the number of channels in each block.
- downsample_tuple: two-element tuple of downsample functions to use. The first one is used when the number of channels changes in the block, the second one is used when the number of channels stays the same.
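To see how stride_fn, planes_fn, expansion and downsample_tuple interact, the sketch below walks the blocks of a layout in order and records which downsample entry the rule above would select. It inlines the default stride rule documented for resnet_stride and assumes the default planes rule is 64 * 2^(stage_idx - 1). describe_blocks is a hypothetical helper, and it checks only the channel condition stated above.

```julia
function describe_blocks(block_repeats::AbstractVector{<:Integer};
                         inplanes::Integer = 64, expansion::Integer = 1)
    rows = NamedTuple[]
    for (stage_idx, nblocks) in enumerate(block_repeats), block_idx in 1:nblocks
        planes = 64 * 2^(stage_idx - 1)                        # assumed planes rule
        stride = (stage_idx == 1 || block_idx != 1) ? 1 : 2    # documented stride rule
        outplanes = planes * expansion
        # first entry of downsample_tuple if the channel count changes, else second
        downsample = inplanes != outplanes ? :first : :second
        push!(rows, (; stage_idx, block_idx, inplanes, outplanes, stride, downsample))
        inplanes = outplanes                                   # next block's input
    end
    return rows
end

for row in describe_blocks([3, 4, 6, 3]; expansion = 4)    # ResNet-50 layout
    println(row)
end
```

For the ResNet-50 layout, the very first block already takes the first downsample function even though its stride is 1, because expansion = 4 takes it from 64 to 256 channels; with the ResNet-18 layout ([2, 2, 2, 2], expansion = 1) the first block keeps 64 channels and takes the second one.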
Metalhead.bottleneck_builder
Function: bottleneck_builder(block_repeats::AbstractVector{<:Integer};
inplanes::Integer = 64, cardinality::Integer = 1,
base_width::Integer = 64, reduction_factor::Integer = 1,
expansion::Integer = 4, norm_layer = BatchNorm,
revnorm::Bool = false, activation = relu,
attn_fn = planes -> identity, dropblock_prob = nothing,
stochastic_depth_prob = nothing, stride_fn = resnet_stride,
planes_fn = resnet_planes,
downsample_tuple = (downsample_conv, downsample_identity))
Builder for creating a bottleneck block for a ResNet/ResNeXt model. (reference)
Arguments
- block_repeats: number of repeats of a block in each stage.
- inplanes: number of input channels.
- cardinality: number of groups for the convolutional layer.
- base_width: base width for the convolutional layer.
- reduction_factor: reduction factor for the number of channels in each stage.
- expansion: expansion factor for the number of channels for the block.
- norm_layer: normalization layer to use.
- revnorm: set to true to place the normalization layer before the convolution.
- activation: activation function to use.
- attn_fn: attention function to use.
- dropblock_prob: dropblock probability. Set to nothing to disable DropBlock.
- stochastic_depth_prob: stochastic depth probability. Set to nothing to disable StochasticDepth.
- stride_fn: callback for computing the stride of the block.
- planes_fn: callback for computing the number of channels in each block.
- downsample_tuple: two-element tuple of downsample functions to use. The first one is used when the number of channels changes in the block, the second one is used when the number of channels stays the same.
Metalhead.bottle2neck_builder
Function: bottle2neck_builder(block_repeats::AbstractVector{<:Integer};
inplanes::Integer = 64, cardinality::Integer = 1,
base_width::Integer = 26, scale::Integer = 4,
expansion::Integer = 4, norm_layer = BatchNorm,
revnorm::Bool = false, activation = relu,
attn_fn = planes -> identity, stride_fn = resnet_stride,
planes_fn = resnet_planes,
downsample_tuple = (downsample_conv, downsample_identity))
Builder for creating a bottle2neck block for a Res2Net model. (reference)
Arguments
- block_repeats: number of repeats of a block in each stage.
- inplanes: number of input channels.
- cardinality: number of groups for the convolutional layer.
- base_width: base width for the convolutional layer.
- scale: scale for the number of channels in each block.
- expansion: expansion factor for the number of channels for the block.
- norm_layer: normalization layer to use.
- revnorm: set to true to place the normalization layer before the convolution.
- activation: activation function to use.
- attn_fn: attention function to use.
- stride_fn: callback for computing the stride of the block.
- planes_fn: callback for computing the number of channels in each block.
- downsample_tuple: two-element tuple of downsample functions to use. The first one is used when the number of channels changes in the block, the second one is used when the number of channels stays the same.
Generic ResNet model builder
Metalhead.build_resnet
Function: build_resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
connection, classifier_fn)
Creates a generic ResNet-like model.
This is a very generic and flexible, but low-level, function that can be used to create any of the ResNet variants. For a more user-friendly function, see Metalhead.resnet.
Arguments
- img_dims: The dimensions of the input image. This is used to determine the number of feature maps to be passed to the classifier. This should be a tuple of the form (height, width, channels).
- stem: The stem of the ResNet model. The stem should be created outside of this function and passed in as an argument, which allows for more flexibility in creating the stem. resnet_stem is a helper function provided by Metalhead and is the recommended way to create the stem.
- get_layers: A function that takes two inputs, the stage_idx (the index of the stage) and the block_idx (the index of the block within the stage), and returns a tuple of layers. If the tuple returned by get_layers has more than one element, connection is used to splat this tuple into Parallel; if not, the only element of the tuple is inserted directly into the network. get_layers is a very specific function and should not be written by hand. Instead, use one of the builders provided by Metalhead to create it.
- block_repeats: A Vector of integers that specifies the number of repeats of each block in each stage.
- connection: A function that determines the residual connection in the model. For ResNets, either Metalhead.Layers.addact or Metalhead.Layers.actadd is recommended.
- classifier_fn: A function that takes in the number of feature maps and returns a classifier. This is usually built as a closure using a function like Metalhead.create_classifier. For example, if the number of output classes is nclasses, the function can be defined as channels -> create_classifier(channels, nclasses).
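As a rough guide to what classifier_fn receives, this sketch computes the feature-map count and final spatial size for the standard ResNet layout. It assumes a last stage of 512 planes times the block's expansion and a total downsampling factor of 32 (4 from the stem, 2 from each of the last three stages). dense_head is a hypothetical placeholder for a real classifier builder such as Metalhead.create_classifier.

```julia
# Assumed standard layout: 512 planes in the last stage, overall stride 32
classifier_channels(expansion::Integer) = 512 * expansion
final_spatial(h::Integer, w::Integer) = (h ÷ 32, w ÷ 32)   # exact when divisible by 32

classifier_channels(1)     # 512 for basicblock models such as ResNet-18
classifier_channels(4)     # 2048 for bottleneck models such as ResNet-50
final_spatial(224, 224)    # (7, 7)

# classifier_fn is a closure over nclasses; dense_head is a hypothetical
# zero-initialised linear map standing in for a real classifier head.
nclasses = 10
dense_head(channels, nclasses) = v -> zeros(nclasses, channels) * v
classifier_fn = channels -> dense_head(channels, nclasses)
head = classifier_fn(classifier_channels(4))
length(head(ones(2048)))   # 10
```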
Utility callbacks
Metalhead.resnet_planes
Function: resnet_planes(block_repeats::AbstractVector{<:Integer})
Default callback for determining the number of channels in each block in a ResNet model.
Arguments
- block_repeats: A Vector of integers specifying the number of times each block is repeated in each stage of the ResNet model. For example, [3, 4, 6, 3] is the configuration used in ResNet-50, which has 3 blocks in the first stage, 4 blocks in the second stage, 6 blocks in the third stage and 3 blocks in the fourth stage.
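A plain-Julia sketch of the default rule, assuming (consistent with the standard ResNet widths, but not stated on this page) that every block in stage i uses 64 * 2^(i - 1) planes. resnet_planes_sketch is a hypothetical name; it returns a Vector with one entry per block, and the real callback's return type may differ.

```julia
# Assumed default planes rule: 64 * 2^(stage_idx - 1) for every block in a stage
function resnet_planes_sketch(block_repeats::AbstractVector{<:Integer})
    return [64 * 2^(stage_idx - 1)
            for (stage_idx, n) in enumerate(block_repeats) for _ in 1:n]
end

resnet_planes_sketch([2, 2, 2, 2])   # [64, 64, 128, 128, 256, 256, 512, 512]
```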
Metalhead.resnet_stride
Function: resnet_stride(stage_idx::Integer, block_idx::Integer)
Default callback for determining the stride of a block in a ResNet model. Returns 2 for the first block in every stage except the first stage, and 1 for all other blocks.
Arguments
- stage_idx: The index of the stage in the ResNet model.
- block_idx: The index of the block in the stage.
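The documented rule fits in one line of Julia. resnet_stride_sketch is a hypothetical name, used so the sketch does not shadow the Metalhead function.

```julia
# Stride 2 for the first block of every stage except the first; 1 otherwise
resnet_stride_sketch(stage_idx::Integer, block_idx::Integer) =
    (stage_idx == 1 || block_idx != 1) ? 1 : 2

# Strides for the ResNet-18 layout [2, 2, 2, 2], stage by stage
[[resnet_stride_sketch(s, b) for b in 1:2] for s in 1:4]   # [[1, 1], [2, 1], [2, 1], [2, 1]]
```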
Metalhead.resnet_stem
Function: resnet_stem(; stem_type = :default, inchannels::Integer = 3, replace_stem_pool = false,
norm_layer = BatchNorm, activation = relu)
Builds a stem to be used in a ResNet model. See the stem argument of build_resnet for details on how to use this function.
Arguments
- stem_type: The type of stem to be built. One of [:default, :deep, :deep_tiered].
  - :default: Builds a stem based on the default ResNet stem, which consists of a single 7x7 convolution with stride 2 and a normalisation layer, followed by a 3x3 max pooling layer with stride 2.
  - :deep: Borrows ideas from other papers (InceptionResNetv2, for example) by using a deeper stem with 3 successive 3x3 convolutions, each followed by a normalisation layer. This is followed by a 3x3 max pooling layer with stride 2.
  - :deep_tiered: A variant of the :deep stem that has a larger width in the second convolution. This is an experimental variant from the timm library in Python that shows performance improvements over the :deep stem in some cases.
- inchannels: number of input channels.
- replace_stem_pool: Set to true to replace the max pooling layer with a 3x3 convolution + normalization with a stride of two.
- norm_layer: The normalisation layer used in the stem.
- activation: The activation function used in the stem.
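The spatial effect of the stem can be checked with the usual output-size formula for a window of size k, stride s and padding p, floor((n + 2p - k) / s) + 1. The padding values below (3 for the 7x7 convolution, 1 for the 3x3 layers) and the stride-2 first convolution of the :deep stem are the standard choices and are assumed here; out_size is a hypothetical helper.

```julia
# Output length along one spatial dimension for window k, stride s, padding p
out_size(n; k, s = 1, p = 0) = fld(n + 2p - k, s) + 1

# :default stem on a 224 x 224 input
h = out_size(224; k = 7, s = 2, p = 3)   # 7x7 conv, stride 2 -> 112
h = out_size(h; k = 3, s = 2, p = 1)     # 3x3 max pool, stride 2 -> 56

# :deep stem: three 3x3 convs (only the first strided), then the same max pool
d = out_size(224; k = 3, s = 2, p = 1)   # 112
d = out_size(d; k = 3, p = 1)            # 112
d = out_size(d; k = 3, p = 1)            # 112
d = out_size(d; k = 3, s = 2, p = 1)     # 56
```

Under these assumptions both stems reduce the spatial size by a factor of 4 before the first stage.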