Vision Transformer models

This is the API reference for the Vision Transformer models supported by Metalhead.jl.

The higher-level model constructors

Metalhead.ViT — Type
ViT(config::Symbol = :base; imsize::Dims{2} = (224, 224), inchannels::Integer = 3,
    patch_size::Dims{2} = (16, 16), pool = :class, nclasses::Integer = 1000)

Creates a Vision Transformer (ViT) model. (reference).

Arguments

  • config: the model configuration, one of [:tiny, :small, :base, :large, :huge, :giant, :gigantic]
  • imsize: image size
  • inchannels: number of input channels
  • patch_size: size of the patches
  • pool: pooling type, either :class or :mean
  • nclasses: number of classes in the output

See also Metalhead.vit.

source
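
A minimal usage sketch for the constructor above (it assumes Metalhead.jl and its Flux.jl dependency are installed; the input follows Flux's WHCN layout):

```julia
using Metalhead

# Build a ViT-B/16 classifier with the default 1000 output classes.
model = ViT(:base; imsize = (224, 224), patch_size = (16, 16))

# A dummy batch of one 224×224 RGB image in WHCN
# (width, height, channel, batch) order.
x = rand(Float32, 224, 224, 3, 1)

# Forward pass; yields a (nclasses, batchsize) = (1000, 1) array of logits.
y = model(x)
```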

The mid-level functions

Metalhead.vit — Function
vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3, patch_size::Dims{2} = (16, 16),
    embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_prob = 0.1,
    emb_dropout_prob = 0.1, pool = :class, nclasses::Integer = 1000)

Creates a Vision Transformer (ViT) model. (reference).

Arguments

  • imsize: image size
  • inchannels: number of input channels
  • patch_size: size of the patches
  • embedplanes: the number of channels after the patch embedding
  • depth: number of blocks in the transformer
  • nheads: number of attention heads in the transformer
  • mlp_ratio: ratio of hidden channels in the transformer's MLP blocks to embedplanes
  • dropout_prob: dropout probability
  • emb_dropout_prob: dropout probability for the positional embedding layer
  • pool: pooling type, either :class or :mean
  • nclasses: number of classes in the output

source
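
The mid-level function can be used to build custom ViT variants outside the named configurations. A minimal sketch, assuming Metalhead.jl is installed (the depth, embedding width, and class count here are illustrative, not a named preset):

```julia
using Metalhead

# A small custom transformer: 8 blocks, 512 embedding channels,
# 8 attention heads, 10 output classes.
model = Metalhead.vit((224, 224); patch_size = (16, 16),
                      embedplanes = 512, depth = 8, nheads = 8,
                      mlp_ratio = 4.0, nclasses = 10)

# Dummy batch of one 224×224 RGB image in WHCN order;
# the output is a (10, 1) array of logits.
x = rand(Float32, 224, 224, 3, 1)
y = model(x)
```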