Working with pre-trained models from Metalhead

Using a model from Metalhead is as simple as selecting a model from the table of available models. For example, below we use the pre-trained ResNet-18 model.

using Metalhead

# Instantiate an 18-layer ResNet; `pretrain = true` asks Metalhead to load
# pre-trained weights instead of a randomly initialised model.
model = ResNet(18, pretrain = true);
ResNet(
  Chain(
    Chain(
      Chain(
        Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        MaxPool((3, 3), pad=1, stride=2),
      ),
      Chain(
        Parallel(
          addact(NNlib.relu, ...),
          identity,
          Chain(
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
            NNlib.relu,
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
          ),
        ),
        Parallel(
          addact(NNlib.relu, ...),
          identity,
          Chain(
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
            NNlib.relu,
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
          ),
        ),
      ),
      Chain(
        Parallel(
          addact(NNlib.relu, ...),
          Chain(
            Conv((1, 1), 64 => 128, stride=2, bias=false),  # 8_192 parameters
            BatchNorm(128),             # 256 parameters, plus 256
          ),
          Chain(
            Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false),  # 73_728 parameters
            BatchNorm(128),             # 256 parameters, plus 256
            NNlib.relu,
            Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
            BatchNorm(128),             # 256 parameters, plus 256
          ),
        ),
        Parallel(
          addact(NNlib.relu, ...),
          identity,
          Chain(
            Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
            BatchNorm(128),             # 256 parameters, plus 256
            NNlib.relu,
            Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
            BatchNorm(128),             # 256 parameters, plus 256
          ),
        ),
      ),
      Chain(
        Parallel(
          addact(NNlib.relu, ...),
          Chain(
            Conv((1, 1), 128 => 256, stride=2, bias=false),  # 32_768 parameters
            BatchNorm(256),             # 512 parameters, plus 512
          ),
          Chain(
            Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false),  # 294_912 parameters
            BatchNorm(256),             # 512 parameters, plus 512
            NNlib.relu,
            Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
            BatchNorm(256),             # 512 parameters, plus 512
          ),
        ),
        Parallel(
          addact(NNlib.relu, ...),
          identity,
          Chain(
            Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
            BatchNorm(256),             # 512 parameters, plus 512
            NNlib.relu,
            Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
            BatchNorm(256),             # 512 parameters, plus 512
          ),
        ),
      ),
      Chain(
        Parallel(
          addact(NNlib.relu, ...),
          Chain(
            Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
          ),
          Chain(
            Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false),  # 1_179_648 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
            NNlib.relu,
            Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
          ),
        ),
        Parallel(
          addact(NNlib.relu, ...),
          identity,
          Chain(
            Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
            NNlib.relu,
            Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
          ),
        ),
      ),
    ),
    Chain(
      AdaptiveMeanPool((1, 1)),
      MLUtils.flatten,
      Dense(512 => 1000),               # 513_000 parameters
    ),
  ),
)         # Total: 62 trainable arrays, 11_689_512 parameters,
          # plus 40 non-trainable, 9_600 parameters, summarysize 44.654 MiB.

Using pre-trained models as feature extractors

The backbone and classifier functions do exactly what their names suggest: they extract the backbone and the classifier of a model, respectively. For example, to extract the backbone of a pre-trained ResNet-18 model:

# Pull out the feature-extraction part (everything before the pooling/Dense head).
model |> backbone;
Chain(
  Chain(
    Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
    MaxPool((3, 3), pad=1, stride=2),
  ),
  Chain(
    Parallel(
      addact(NNlib.relu, ...),
      identity,
      Chain(
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
        NNlib.relu,
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
      ),
    ),
    Parallel(
      addact(NNlib.relu, ...),
      identity,
      Chain(
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
        NNlib.relu,
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
      ),
    ),
  ),
  Chain(
    Parallel(
      addact(NNlib.relu, ...),
      Chain(
        Conv((1, 1), 64 => 128, stride=2, bias=false),  # 8_192 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
      ),
      Chain(
        Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false),  # 73_728 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
        NNlib.relu,
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
      ),
    ),
    Parallel(
      addact(NNlib.relu, ...),
      identity,
      Chain(
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
        NNlib.relu,
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
      ),
    ),
  ),
  Chain(
    Parallel(
      addact(NNlib.relu, ...),
      Chain(
        Conv((1, 1), 128 => 256, stride=2, bias=false),  # 32_768 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      Chain(
        Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false),  # 294_912 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
        NNlib.relu,
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
    ),
    Parallel(
      addact(NNlib.relu, ...),
      identity,
      Chain(
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
        NNlib.relu,
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
    ),
  ),
  Chain(
    Parallel(
      addact(NNlib.relu, ...),
      Chain(
        Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      Chain(
        Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false),  # 1_179_648 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
        NNlib.relu,
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
    ),
    Parallel(
      addact(NNlib.relu, ...),
      identity,
      Chain(
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
        NNlib.relu,
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
    ),
  ),
)         # Total: 60 trainable arrays, 11_176_512 parameters,
          # plus 40 non-trainable, 9_600 parameters, summarysize 42.693 MiB.

The backbone function is also useful if you want to reuse only specific sections of the model for transfer learning. Since it returns a Chain of the model's layers, you can index into it to get the layers you want. For example, to get the first five layers of a pre-trained ResNet model, write backbone(model)[1:5].

Training

Now we can use this model with Flux like any other model. First, let's check its prediction on a test image, using the ImageNet class labels.

using Images

# Fetch a sample photo to classify and decode it into an image.
img = "https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg" |>
      download |>
      Images.load;

We'll use the popular DataAugmentation.jl library to crop our input image, convert it to a plain array, and normalize the pixels.

using DataAugmentation
using Flux
using Flux: onecold

# Per-channel mean and standard deviation used to normalise the input pixels.
DATA_MEAN = (0.485, 0.456, 0.406)
DATA_STD = (0.229, 0.224, 0.225)

# Crop the central 224×224 region, convert it to a plain array, then normalise it.
augmentations = CenterCrop((224, 224)) |> ImageToTensor() |> Normalize(DATA_MEAN, DATA_STD)

data = itemdata(apply(augmentations, Image(img)))

# ImageNet class names, read one per line from the downloaded text file.
labels = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" |>
         download |>
         readlines

# Append a singleton 4th (batch) axis, run the model, and map the top score to its label.
println(onecold(Flux.unsqueeze(data, 4) |> model, labels))
["acoustic guitar"]

That is fairly accurate! Below, we train the model on some randomly generated data:

using Optimisers
using Flux: onehotbatch
using Flux.Losses: logitcrossentropy

# Three random batches: 224×224×3 Float32 inputs with random one-hot targets over 1000 classes.
batchsize = 1
data = [(rand(Float32, 224, 224, 3, batchsize), onehotbatch(rand(1:1000, batchsize), 1:1000))
        for _ in 1:3]
opt = Optimisers.Adam()
state = Optimisers.setup(opt, model);  # initialise this optimiser's state
for (i, (image, y)) in enumerate(data)
    # Rebind the top-level `state` and `model` rather than creating loop-locals, so the
    # loop also works when run as a script (outside the REPL's soft-scope rules).
    global state, model
    @info "Starting batch $i ..."
    # Differentiate the loss with respect to the model only; the input gradient is unused.
    gs, = gradient(m -> logitcrossentropy(m(image), y), model)
    state, model = Optimisers.update(state, model, gs)
end