Using MLJ to classify the MNIST image dataset

This tutorial is available as a Jupyter notebook or Julia script here.

Julia version is assumed to be 1.10.*

using MLJ
using Flux
import MLJFlux
import MLUtils
import MLJIteration # for `skip`

If running on a GPU, you will also need to import CUDA and import cuDNN.

using Plots
gr(size=(600, 300*(sqrt(5)-1)));

Basic training

Downloading the MNIST image dataset:

import MLDatasets: MNIST

ENV["DATADEPS_ALWAYS_ACCEPT"] = true
images, labels = MNIST(split=:train)[:];

In MLJ, integers cannot be used for encoding categorical data, so we must force the labels to have the Multiclass scientific type. For more on this, see Working with Categorical Data.

labels = coerce(labels, Multiclass);
images = coerce(images, GrayImage);

Checking scientific types:

@assert scitype(images) <: AbstractVector{<:Image}
@assert scitype(labels) <: AbstractVector{<:Finite}

Looks good.

For general instructions on coercing image data, see Type coercion for image data.

images[1]

We start by defining a suitable Builder object. This is a recipe for building the neural network. Our builder will work for images of any (constant) size, whether they be color or black and white (i.e., single- or multi-channel). The architecture always consists of six alternating convolution and max-pool layers, followed by a flattening step and a final dense layer; the filter size and the number of channels after each convolution layer are customisable.

import MLJFlux
struct MyConvBuilder
    filter_size::Int
    channels1::Int
    channels2::Int
    channels3::Int
end

function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels)
    k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3
    mod(k, 2) == 1 || error("`filter_size` must be odd. ")
    p = div(k - 1, 2) # padding to preserve image size
    init = Flux.glorot_uniform(rng)
    front = Chain(
        Conv((k, k), n_channels => c1, pad=(p, p), relu, init=init),
        MaxPool((2, 2)),
        Conv((k, k), c1 => c2, pad=(p, p), relu, init=init),
        MaxPool((2, 2)),
        Conv((k, k), c2 => c3, pad=(p, p), relu, init=init),
        MaxPool((2, 2)),
        MLUtils.flatten)
    d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
    return Chain(front, Dense(d, n_out, init=init))
end
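As a check on MyConvBuilder, the input size of the final Dense layer can be traced by hand for 28×28 MNIST images with filter_size = 3 and channels (16, 32, 32), as used below. This sketch uses only Base Julia; conv_out, pool_out and flattened_size are hypothetical helpers, and it assumes square images, stride-1 convolutions, and 2×2 max pooling with stride 2 and no padding (the defaults for Flux's Conv and MaxPool). In the builder itself, Flux.outputsize does this computation for the real network.

```julia
# Output side length of a stride-1 convolution with kernel size k and padding p
conv_out(n, k, p) = n + 2p - k + 1

# Output side length of a 2x2 max pool with stride 2 and no padding (rounds down)
pool_out(n) = div(n - 2, 2) + 1

# Hypothetical helper: length of the flattened feature vector fed to Dense
function flattened_size(n_in::Int, k::Int, channels::Vector{Int})
    isodd(k) || error("filter size must be odd")
    p = div(k - 1, 2)            # same padding rule as MyConvBuilder
    n = n_in
    for _ in channels            # one conv + max-pool stage per channel count
        n = conv_out(n, k, p)    # convolution preserves the side length
        n = pool_out(n)          # pooling halves it, rounding down
    end
    return n * n * last(channels)  # height * width * channels after flatten
end

d = flattened_size(28, 3, [16, 32, 32])
println(d)  # 288
```

The side length goes 28 → 14 → 7 → 3, so the flattened vector has 3 * 3 * 32 = 288 entries, which agrees with the Dense(288 => 10) layer in the fitted chain printed later in this tutorial.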

Notes.

  • There is no final softmax here, as this is applied by default in all MLJFlux classifiers. This behaviour can be customised using the finaliser hyperparameter of the classifier.

  • Instead of calculating the padding p, Flux can infer the required padding in each dimension, which you enable by replacing pad = (p, p) with pad = SamePad().

We now define the MLJ model.

ImageClassifier = @load ImageClassifier
clf = ImageClassifier(
    builder=MyConvBuilder(3, 16, 32, 32),
    batch_size=50,
    epochs=10,
    rng=123,
)
ImageClassifier(
  builder = Main.MyConvBuilder(3, 16, 32, 32), 
  finaliser = NNlib.softmax, 
  optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), 
  loss = Flux.Losses.crossentropy, 
  epochs = 10, 
  batch_size = 50, 
  lambda = 0.0, 
  alpha = 0.0, 
  rng = 123, 
  optimiser_changes_trigger_retraining = false, 
  acceleration = CPU1{Nothing}(nothing))

You can add Flux options optimiser=... and loss=... in the above constructor call. At present, loss must be a Flux-compatible loss, not an MLJ measure. To run on a GPU, add to the constructor acceleration=CUDALib() and omit rng.

For illustration purposes, we won't use all the data here:

train = 1:500
test = 501:1000
501:1000

Binding the model with data in an MLJ machine:

mach = machine(clf, images, labels);

Training for 10 epochs on the first 500 images:

fit!(mach, rows=train, verbosity=2);
[ Info: Training machine(ImageClassifier(builder = Main.MyConvBuilder(3, 16, 32, 32), …), …).
[ Info: Loss is 2.28
[ Info: Loss is 2.171
[ Info: Loss is 1.942
[ Info: Loss is 1.505
[ Info: Loss is 0.9922
[ Info: Loss is 0.6912
[ Info: Loss is 0.5584
[ Info: Loss is 0.4542
[ Info: Loss is 0.3809
[ Info: Loss is 0.3272

Inspecting:

report(mach)
(training_losses = Float32[2.3174262, 2.280439, 2.1711705, 1.9420795, 1.5045885, 0.99224484, 0.69117606, 0.5583703, 0.45424515, 0.38085267, 0.3271538],)
chain = fitted_params(mach)
(chain = Chain(Chain(Chain(Conv((3, 3), 1 => 16, relu, pad=1), MaxPool((2, 2)), Conv((3, 3), 16 => 32, relu, pad=1), MaxPool((2, 2)), Conv((3, 3), 32 => 32, relu, pad=1), MaxPool((2, 2)), flatten), Dense(288 => 10)), softmax),)
Flux.params(chain)[2]
16-element Vector{Float32}:
 0.003225543
 0.019304937
 0.062040687
 0.024518687
 0.05317823
 0.069572166
 0.044410173
 0.024950704
 0.015806748
 0.015081032
 0.017513964
 0.02133927
 0.040562775
 0.0018777152
 0.055122323
 0.057923194

Adding 20 more epochs:

clf.epochs = clf.epochs + 20
fit!(mach, rows=train);
[ Info: Updating machine(ImageClassifier(builder = Main.MyConvBuilder(3, 16, 32, 32), …), …).

Optimising neural net: 100%[=========================] Time: 0:00:14

Computing an out-of-sample estimate of the loss:

predicted_labels = predict(mach, rows=test);
cross_entropy(predicted_labels, labels[test])
0.4883231265583621

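Or, to fit and evaluate in a single call (Holdout does not shuffle by default, so this uses the same split as above: training on rows 1:500 and testing on rows 501:1000):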

evaluate!(mach,
          resampling=Holdout(fraction_train=0.5),
          measure=cross_entropy,
          rows=1:1000,
          verbosity=0)
PerformanceEvaluation object with these fields:
  model, measure, operation,
  measurement, per_fold, per_observation,
  fitted_params_per_fold, report_per_fold,
  train_test_rows, resampling, repeats
Extract:
┌──────────────────────┬───────────┬─────────────┐
│ measure              │ operation │ measurement │
├──────────────────────┼───────────┼─────────────┤
│ LogLoss(             │ predict   │ 0.488       │
│   tol = 2.22045e-16) │           │             │
└──────────────────────┴───────────┴─────────────┘
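Both estimates above are the mean negative log-probability that the classifier assigns to the true label of each test image. The Base-Julia sketch below spells out that formula; it is not MLJ's implementation, and it assumes predictions are held as a matrix of class probabilities (one row per observation) with labels given as integer class indices. Probabilities are clamped to [tol, 1 - tol] with tol = eps(), in the spirit of the tol = 2.22045e-16 field shown in the LogLoss output, so that a zero probability gives a large finite loss rather than Inf.

```julia
# Mean cross-entropy (log loss) for probabilistic predictions (illustrative sketch).
# probs[i, j] is the predicted probability that observation i belongs to class j;
# y[i] is the index of the true class of observation i.
function mean_cross_entropy(probs::AbstractMatrix{<:Real}, y::AbstractVector{<:Integer};
                            tol::Real = eps())
    n = length(y)
    size(probs, 1) == n || throw(DimensionMismatch("need one row of probabilities per label"))
    total = 0.0
    for i in 1:n
        p = clamp(probs[i, y[i]], tol, 1 - tol)  # keep log finite
        total += -log(p)
    end
    return total / n
end

probs = [0.7 0.2 0.1;
         0.1 0.8 0.1]
y = [1, 2]
println(mean_cross_entropy(probs, y))  # ≈ 0.2899, i.e. (-log(0.7) - log(0.8)) / 2
```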

Wrapping the MLJFlux model with iteration controls

Any iterative MLJFlux model can be wrapped in iteration controls, as we demonstrate next. For more on MLJ's IteratedModel wrapper, see the MLJ documentation.

The "self-iterating" classifier, called iterated_clf below, iterates the image classifier defined above until one of the following stopping criteria applies:

  • Patience(3): 3 consecutive increases in the loss
  • InvalidValue(): an out-of-sample loss, or a training loss, is NaN, Inf, or -Inf
  • TimeLimit(t=5/60): training time has exceeded 5 minutes
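TimeLimit measures t in hours, so t = 5/60 is five minutes; this is why the printed iterated model further down shows Dates.Millisecond(300000). The conversion can be checked with the Dates standard library (rounding to whole milliseconds is an assumption of this sketch, not necessarily what EarlyStopping.jl does internally):

```julia
using Dates

t_hours = 5 / 60                            # value passed to TimeLimit (hours)
ms = round(Int, t_hours * 60 * 60 * 1000)   # hours -> minutes -> seconds -> milliseconds
limit = Millisecond(ms)

println(limit)                              # 300000 milliseconds
println(Dates.value(limit) / (60 * 1000))   # 5.0 (minutes)
```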

These checks (and other controls) will be applied every two epochs (because of the Step(2) control). Additionally, training a machine bound to iterated_clf will:

  • save a snapshot of the machine every three control cycles (every six epochs)
  • record traces of the out-of-sample loss and training losses for plotting
  • record mean value traces of each Flux parameter for plotting
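The save cadence follows from combining Step(2) with MLJIteration.skip and predicate=3. The sketch below enumerates which control cycles, and hence which epochs, would trigger a save over 12 cycles (the number reached in the training run shown later). It models an integer predicate as "act on every third cycle"; that reading of skip is an assumption, and the file names simply number the saves in order, as the training log does.

```julia
epochs_per_cycle = 2   # Step(2): each control cycle trains two more epochs
every = 3              # integer predicate given to skip: act on every third cycle
n_cycles = 12          # cycles completed in the run shown later (24 epochs)

save_cycles = [c for c in 1:n_cycles if c % every == 0]
save_epochs = epochs_per_cycle .* save_cycles
snapshot_files = ["mnist$(i).jls" for i in eachindex(save_cycles)]

println(save_cycles)     # [3, 6, 9, 12]
println(save_epochs)     # [6, 12, 18, 24]
println(snapshot_files)  # ["mnist1.jls", "mnist2.jls", "mnist3.jls", "mnist4.jls"]
```

Under this reading, the mnist3.jls snapshot restored later in the tutorial holds the model after 18 epochs.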

For a complete list of controls, see this table.

Wrapping the classifier

Some helpers

To extract Flux params from an MLJFlux machine:

parameters(mach) = vec.(Flux.params(fitted_params(mach)));

To store the traces:

losses = []
training_losses = []
parameter_means = Float32[];
epochs = []
Any[]

To update the traces:

update_loss(loss) = push!(losses, loss)
update_training_loss(losses) = push!(training_losses, losses[end])
update_means(mach) = append!(parameter_means, mean.(parameters(mach)));
update_epochs(epoch) = push!(epochs, epoch)
update_epochs (generic function with 1 method)

The controls to apply:

save_control =
    MLJIteration.skip(Save(joinpath(tempdir(), "mnist.jls")), predicate=3)

controls=[
    Step(2),
    Patience(3),
    InvalidValue(),
    TimeLimit(5/60),
    save_control,
    WithLossDo(),
    WithLossDo(update_loss),
    WithTrainingLossesDo(update_training_loss),
    Callback(update_means),
    WithIterationsDo(update_epochs),
];

The "self-iterating" classifier:

iterated_clf = IteratedModel(
    clf,
    controls=controls,
    resampling=Holdout(fraction_train=0.7),
    measure=log_loss,
)
ProbabilisticIteratedModel(
  model = ImageClassifier(
        builder = Main.MyConvBuilder(3, 16, 32, 32), 
        finaliser = NNlib.softmax, 
        optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8), 
        loss = Flux.Losses.crossentropy, 
        epochs = 30, 
        batch_size = 50, 
        lambda = 0.0, 
        alpha = 0.0, 
        rng = 123, 
        optimiser_changes_trigger_retraining = false, 
        acceleration = CPU1{Nothing}(nothing)), 
  controls = Any[IterationControl.Step(2), EarlyStopping.Patience(3), EarlyStopping.InvalidValue(), EarlyStopping.TimeLimit(Dates.Millisecond(300000)), IterationControl.Skip{MLJIteration.Save{typeof(Serialization.serialize)}, IterationControl.var"#8#9"{Int64}}(MLJIteration.Save{typeof(Serialization.serialize)}("/tmp/mnist.jls", Serialization.serialize), IterationControl.var"#8#9"{Int64}(3)), IterationControl.WithLossDo{IterationControl.var"#20#22"}(IterationControl.var"#20#22"(), false, nothing), IterationControl.WithLossDo{typeof(Main.update_loss)}(Main.update_loss, false, nothing), IterationControl.WithTrainingLossesDo{typeof(Main.update_training_loss)}(Main.update_training_loss, false, nothing), IterationControl.Callback{typeof(Main.update_means)}(Main.update_means, false, nothing, false), MLJIteration.WithIterationsDo{typeof(Main.update_epochs)}(Main.update_epochs, false, nothing)], 
  resampling = Holdout(
        fraction_train = 0.7, 
        shuffle = false, 
        rng = Random._GLOBAL_RNG()), 
  measure = LogLoss(tol = 2.22045e-16), 
  weights = nothing, 
  class_weights = nothing, 
  operation = nothing, 
  retrain = false, 
  check_measure = true, 
  iteration_parameter = nothing, 
  cache = true)

Binding the wrapped model to data:

mach = machine(iterated_clf, images, labels);

Training

fit!(mach, rows=train);
[ Info: Training machine(ProbabilisticIteratedModel(model = ImageClassifier(builder = Main.MyConvBuilder(3, 16, 32, 32), …), …), …).
[ Info: No iteration parameter specified. Using `iteration_parameter=:(epochs)`.
[ Info: loss: 2.195050130190149
[ Info: loss: 1.8450074691283658
[ Info: Saving "/tmp/mnist1.jls".
[ Info: loss: 1.1388123685158849
[ Info: loss: 0.702997545486733
[ Info: loss: 0.5778269559910739
[ Info: Saving "/tmp/mnist2.jls".
[ Info: loss: 0.5222495075757826
[ Info: loss: 0.49847208228951995
[ Info: loss: 0.4897800580510804
[ Info: Saving "/tmp/mnist3.jls".
[ Info: loss: 0.4893840844808948
[ Info: loss: 0.49094569068535143
[ Info: loss: 0.49593260647952264
[ Info: Saving "/tmp/mnist4.jls".
[ Info: loss: 0.5062357308150314
[ Info: final loss: 0.5062357308150314
[ Info: final training loss: 0.059303638
[ Info: Stop triggered by EarlyStopping.Patience(3) stopping criterion.
[ Info: Total of 24 iterations.
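Patience(3) stops training once the out-of-sample loss has increased on three consecutive checks. Feeding the twelve losses from the log above into a hand-written version of that rule (a sketch; patience_stop is a hypothetical helper, not EarlyStopping.jl's code) reproduces the stop after 12 control cycles, that is, 24 epochs:

```julia
# Out-of-sample losses copied from the training log above, one per control cycle
losses_log = [2.195050130190149, 1.8450074691283658, 1.1388123685158849,
              0.702997545486733, 0.5778269559910739, 0.5222495075757826,
              0.49847208228951995, 0.4897800580510804, 0.4893840844808948,
              0.49094569068535143, 0.49593260647952264, 0.5062357308150314]

# Index of the first cycle at which the loss has risen on `n` consecutive
# checks, or `nothing` if that never happens (hypothetical helper)
function patience_stop(losses, n)
    streak = 0
    for i in 2:length(losses)
        streak = losses[i] > losses[i - 1] ? streak + 1 : 0
        streak == n && return i
    end
    return nothing
end

stop_cycle = patience_stop(losses_log, 3)
println(stop_cycle)      # 12
println(2 * stop_cycle)  # 24, because Step(2) trains two epochs per cycle
```

The loss bottoms out at cycle 9 and then rises at cycles 10, 11 and 12. Under the stricter Patience(4), this sequence alone would not trigger a stop.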

Comparison of the training and out-of-sample losses:

plot(
    epochs,
    losses,
    xlab = "epoch",
    ylab = "cross entropy",
    label="out-of-sample",
)
plot!(epochs, training_losses, label="training")

savefig(joinpath(tempdir(), "loss.png"))
"/tmp/loss.png"

Evolution of weights

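n_epochs = length(losses)  # number of control cycles recorded (one entry per Step(2))
n_parameters = div(length(parameter_means), n_epochs)  # parameter arrays per cycle
parameter_means2 = reshape(copy(parameter_means), n_parameters, n_epochs)'  # one row per cycle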
plot(
    epochs,
    parameter_means2,
    title="Flux parameter mean weights",
    xlab = "epoch",
)

Note. The higher the number in the plot legend, the deeper the layer whose weights are being averaged.
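The reshaping step relies on update_means appending one block of n_parameters means to the flat vector parameter_means at every control cycle. A toy version with arbitrary values (two parameter arrays, three cycles) shows why reshape followed by a transpose yields one row per cycle and one column, that is one plot series, per parameter array:

```julia
# Toy trace: 2 parameter arrays observed over 3 cycles, appended cycle by cycle
# (cycle 1: 0.1, 0.2; cycle 2: 0.3, 0.4; cycle 3: 0.5, 0.6)
flat = Float32[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

n_cycles = 3
n_params = div(length(flat), n_cycles)   # 2 parameter arrays per cycle

# reshape fills column by column, so each column holds one cycle's means;
# the transpose (') then gives one row per cycle and one column per parameter array
trace = reshape(copy(flat), n_params, n_cycles)'

println(size(trace))   # (3, 2)
println(trace[:, 1])   # Float32[0.1, 0.3, 0.5]: first parameter array across cycles
```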

savefig(joinpath(tempdir(), "weights.png"))
"/tmp/weights.png"

Retrieving a snapshot for a prediction:

mach2 = machine(joinpath(tempdir(), "mnist3.jls"))
predict_mode(mach2, images[501:503])
3-element CategoricalArrays.CategoricalArray{Int64,1,UInt32}:
 7
 9
 5

Restarting training

Mutating iterated_clf.controls or clf.epochs (which is otherwise ignored) will allow you to restart training from where it left off.

iterated_clf.controls[2] = Patience(4)
fit!(mach, rows=train)

plot(
    epochs,
    losses,
    xlab = "epoch",
    ylab = "cross entropy",
    label="out-of-sample",
)
plot!(epochs, training_losses, label="training")

This page was generated using Literate.jl.