Quickstart
Using a model from Metalhead is as simple as selecting a model from the table of available models. For example, below we use the pre-trained ResNet-18 model.
using Flux, Metalhead
# Instantiate an 18-layer ResNet and load its pre-trained weights.
model = ResNet(18, pretrain = true)
ResNet(
Chain(
Chain([
Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false), # 9_408 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
MaxPool((3, 3), pad=1, stride=2),
Parallel(
Metalhead.addrelu,
Chain(
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64), # 128 parameters, plus 128
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64), # 128 parameters, plus 128
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false), # 73_728 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, bias=false), # 147_456 parameters
BatchNorm(128), # 256 parameters, plus 256
),
Chain([
Conv((1, 1), 64 => 128, stride=2, bias=false), # 8_192 parameters
BatchNorm(128), # 256 parameters, plus 256
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((3, 3), 128 => 128, pad=1, bias=false), # 147_456 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, bias=false), # 147_456 parameters
BatchNorm(128), # 256 parameters, plus 256
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false), # 294_912 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256), # 512 parameters, plus 512
),
Chain([
Conv((1, 1), 128 => 256, stride=2, bias=false), # 32_768 parameters
BatchNorm(256), # 512 parameters, plus 512
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256), # 512 parameters, plus 512
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false), # 1_179_648 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1, bias=false), # 2_359_296 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
Chain([
Conv((1, 1), 256 => 512, stride=2, bias=false), # 131_072 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((3, 3), 512 => 512, pad=1, bias=false), # 2_359_296 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1, bias=false), # 2_359_296 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
identity,
),
]),
Chain(
AdaptiveMeanPool((1, 1)),
MLUtils.flatten,
Dense(512 => 1000), # 513_000 parameters
),
),
) # Total: 62 trainable arrays, 11_689_512 parameters,
# plus 40 non-trainable, 9_600 parameters, summarysize 44.642 MiB.
Now, we can use this model with Flux like any other model.
First, let’s check the model’s prediction on a single test image, using the ImageNet class labels.
using Images
# Fetch a sample photo (a guitar, per the URL) to run through the classifier.
img = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg") |>
      Images.load;
\begin{figure}
\centering
\includegraphics[max width=\linewidth]{15415789024014242692.png}
\caption{The test image (a guitar) downloaded above.}
\end{figure}
We’ll use the popular DataAugmentation.jl library to crop our input image, convert it to a plain array, and normalize the pixels.
using DataAugmentation
# Per-channel normalization statistics (the standard ImageNet mean/std, presumably in
# RGB channel order to match the pre-trained weights).
DATA_MEAN = (0.485, 0.456, 0.406)
DATA_STD = (0.229, 0.224, 0.225)
# Pipeline: take the central 224×224 crop, convert the image to a plain array,
# then normalize each channel with the statistics above.
augmentations = CenterCrop((224, 224)) |> ImageToTensor() |>
                Normalize(DATA_MEAN, DATA_STD)
# Run the pipeline on the loaded image and unwrap the resulting array.
data = itemdata(apply(augmentations, Image(img)))
# ImageNet class names, one per line; assumed to be in the same order as the model's
# 1000 output logits (TODO confirm against the pre-trained weights' source).
labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))
# Append a trailing batch dimension of size 1 (Flux layers expect the batch last),
# then map the highest-scoring logit to its class name.
Flux.onecold(model(Flux.unsqueeze(data; dims = 4)), labels)
1-element Vector{String}:
"acoustic guitar"
Below, we train the model for a few steps on randomly generated data to demonstrate a basic training loop.
using Flux: onehotbatch
batchsize = 1
# Three random (image batch, one-hot target) pairs shaped like the real input above:
# 224×224×3 images with the batch dimension last, and targets over the 1000 classes.
data = [(rand(Float32, 224, 224, 3, batchsize),
         onehotbatch(rand(1:1000, batchsize), 1:1000))
        for _ in 1:3]
# Explicit-style optimiser state (Flux ≥ 0.13.9). This replaces the deprecated
# implicit `ADAM()` / `Flux.params` API, which was removed in Flux 0.15.
opt = Flux.setup(Adam(), model)
loss(x, y, m) = Flux.Losses.logitcrossentropy(m(x), y)
for (i, (x, y)) in enumerate(data)
    @info "Starting batch $i ..."
    # Differentiate with respect to the model itself; `grads[1]` mirrors its structure.
    grads = gradient(m -> loss(x, y, m), model)
    Flux.update!(opt, model, grads[1])
end