Using MLJ to classify the MNIST image dataset
This tutorial is available as a Jupyter notebook or Julia script here.
The Julia version is assumed to be 1.10.*
using MLJ
using Flux
import MLJFlux
import MLUtils
import MLJIteration # for `skip`
If running on a GPU, you will also need to `import CUDA` and `import cuDNN`.
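For example (left commented out here so the tutorial also runs on a CPU):

# import CUDA
# import cuDNN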
using Plots
gr(size=(600, 300*(sqrt(5)-1)));
Basic training
Downloading the MNIST image dataset:
import MLDatasets: MNIST
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
images, labels = MNIST(split=:train)[:];
In MLJ, integers cannot be used for encoding categorical data, so we must force the labels to have the Multiclass
scientific type. For more on this, see Working with Categorical Data.
labels = coerce(labels, Multiclass);
images = coerce(images, GrayImage);
Checking scientific types:
@assert scitype(images) <: AbstractVector{<:Image}
@assert scitype(labels) <: AbstractVector{<:Finite}
Looks good.
For general instructions on coercing image data, see Type coercion for image data
images[1]
We start by defining a suitable `Builder` object. This is a recipe for building the neural network. Our builder will work for images of any (constant) size, whether they be color or black and white (i.e., single or multi-channel). The architecture always consists of six alternating convolution and max-pool layers, and a final dense layer; the filter size and the number of channels after each convolution layer are customisable.
import MLJFlux
struct MyConvBuilder
    filter_size::Int
    channels1::Int
    channels2::Int
    channels3::Int
end

function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels)
    k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3
    mod(k, 2) == 1 || error("`filter_size` must be odd.")
    p = div(k - 1, 2) # padding to preserve image size
    init = Flux.glorot_uniform(rng)
    front = Chain(
        Conv((k, k), n_channels => c1, relu, pad=(p, p), init=init),
        MaxPool((2, 2)),
        Conv((k, k), c1 => c2, relu, pad=(p, p), init=init),
        MaxPool((2, 2)),
        Conv((k, k), c2 => c3, relu, pad=(p, p), init=init),
        MaxPool((2, 2)),
        MLUtils.flatten)
    d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
    return Chain(front, Dense(d, n_out, init=init))
end
Notes.

- There is no final `softmax` here, as this is applied by default in all MLJFlux classifiers. Customisation of this behaviour is controlled using the `finaliser` hyperparameter of the classifier.
- Instead of calculating the padding `p`, Flux can infer the required padding in each dimension, which you enable by replacing `pad=(p, p)` with `pad=SamePad()`.
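As an optional sanity check (not part of the original tutorial), we can call the builder by hand. For 28×28 single-channel MNIST images, the three 2×2 max-pools reduce the spatial size 28 → 14 → 7 → 3, so the flattened feature length should be 3×3×32 = 288, matching the `Dense(288 => 10)` layer appearing in the fitted chain further below. The RNG below is arbitrary, chosen just for illustration:

import Random
builder_check = MyConvBuilder(3, 16, 32, 32)
chain_check = MLJFlux.build(builder_check, Random.MersenneTwister(123), (28, 28), 10, 1)
# `chain_check[1]` is the convolutional front; its flattened output has length 288:
@assert Flux.outputsize(chain_check[1], (28, 28, 1, 1)) == (288, 1)
@assert Flux.outputsize(chain_check, (28, 28, 1, 1)) == (10, 1)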
We now define the MLJ model.
ImageClassifier = @load ImageClassifier
clf = ImageClassifier(
builder=MyConvBuilder(3, 16, 32, 32),
batch_size=50,
epochs=10,
rng=123,
)
ImageClassifier(
builder = Main.MyConvBuilder(3, 16, 32, 32),
finaliser = NNlib.softmax,
optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8),
loss = Flux.Losses.crossentropy,
epochs = 10,
batch_size = 50,
lambda = 0.0,
alpha = 0.0,
rng = 123,
optimiser_changes_trigger_retraining = false,
acceleration = CPU1{Nothing}(nothing))
You can add Flux options `optimiser=...` and `loss=...` in the above constructor call. At present, `loss` must be a Flux-compatible loss, not an MLJ measure. To run on a GPU, add `acceleration=CUDALibs()` to the constructor and omit `rng`.
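For illustration, a GPU variant of the constructor with a custom learning rate might look like the following sketch (left commented out; it assumes CUDA and cuDNN have been imported, and that `CUDALibs` from ComputationalResources.jl is in scope):

# clf_gpu = ImageClassifier(
#     builder=MyConvBuilder(3, 16, 32, 32),
#     batch_size=50,
#     epochs=10,
#     optimiser=Adam(0.0003),   # a smaller learning rate, for illustration
#     acceleration=CUDALibs(),  # train on the GPU
# )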
For illustration purposes, we won't use all the data here:
train = 1:500
test = 501:1000
501:1000
Binding the model with data in an MLJ machine:
mach = machine(clf, images, labels);
Training for 10 epochs on the first 500 images:
fit!(mach, rows=train, verbosity=2);
[ Info: Training machine(ImageClassifier(builder = Main.MyConvBuilder(3, 16, 32, 32), …), …).
[ Info: Loss is 2.28
[ Info: Loss is 2.171
[ Info: Loss is 1.942
[ Info: Loss is 1.505
[ Info: Loss is 0.9922
[ Info: Loss is 0.6912
[ Info: Loss is 0.5584
[ Info: Loss is 0.4542
[ Info: Loss is 0.3809
[ Info: Loss is 0.3272
Inspecting:
report(mach)
(training_losses = Float32[2.3174262, 2.280439, 2.1711705, 1.9420795, 1.5045885, 0.99224484, 0.69117606, 0.5583703, 0.45424515, 0.38085267, 0.3271538],)
chain = fitted_params(mach)
(chain = Chain(Chain(Chain(Conv((3, 3), 1 => 16, relu, pad=1), MaxPool((2, 2)), Conv((3, 3), 16 => 32, relu, pad=1), MaxPool((2, 2)), Conv((3, 3), 32 => 32, relu, pad=1), MaxPool((2, 2)), flatten), Dense(288 => 10)), softmax),)
Flux.params(chain)[2]
16-element Vector{Float32}:
0.003225543
0.019304937
0.062040687
0.024518687
0.05317823
0.069572166
0.044410173
0.024950704
0.015806748
0.015081032
0.017513964
0.02133927
0.040562775
0.0018777152
0.055122323
0.057923194
Adding 20 more epochs:
clf.epochs = clf.epochs + 20
fit!(mach, rows=train);
[ Info: Updating machine(ImageClassifier(builder = Main.MyConvBuilder(3, 16, 32, 32), …), …).
Optimising neural net: 100%[=========================] Time: 0:00:14
Computing an out-of-sample estimate of the loss:
predicted_labels = predict(mach, rows=test);
cross_entropy(predicted_labels, labels[test])
0.4883231265583621
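The same probabilistic predictions can be fed to other measures. For example (an optional check, not part of the original tutorial), the misclassification rate computed from the corresponding point predictions:

misclassification_rate(mode.(predicted_labels), labels[test])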
Or to fit and predict, in one line:
evaluate!(mach,
resampling=Holdout(fraction_train=0.5),
measure=cross_entropy,
rows=1:1000,
verbosity=0)
PerformanceEvaluation object with these fields:
model, measure, operation,
measurement, per_fold, per_observation,
fitted_params_per_fold, report_per_fold,
train_test_rows, resampling, repeats
Extract:
┌──────────────────────┬───────────┬─────────────┐
│ measure │ operation │ measurement │
├──────────────────────┼───────────┼─────────────┤
│ LogLoss( │ predict │ 0.488 │
│ tol = 2.22045e-16) │ │ │
└──────────────────────┴───────────┴─────────────┘
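Several measures can be requested at once. For instance, the following sketch additionally computes `accuracy`, for which MLJ automatically switches to the `predict_mode` operation:

evaluate!(mach,
          resampling=Holdout(fraction_train=0.5),
          measure=[cross_entropy, accuracy],
          rows=1:1000,
          verbosity=0)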
Wrapping the MLJFlux model with iteration controls
Any iterative MLJFlux model can be wrapped in iteration controls, as we demonstrate next. For more on MLJ's `IteratedModel` wrapper, see the MLJ documentation.
The "self-iterating" classifier, called iterated_clf
below, is for iterating the image classifier defined above until one of the following stopping criterion apply:
Patience(3)
: 3 consecutive increases in the lossInvalidValue()
: an out-of-sample loss, or a training loss, isNaN
,Inf
, or-Inf
TimeLimit(t=5/60)
: training time has exceeded 5 minutes
These checks (and other controls) will be applied every two epochs (because of the `Step(2)` control). Additionally, training a machine bound to `iterated_clf` will:
- save a snapshot of the machine every three control cycles (every six epochs)
- record traces of the out-of-sample loss and training losses for plotting
- record mean value traces of each Flux parameter for plotting
For a complete list of controls, see this table.
Wrapping the classifier
Some helpers
To extract Flux params from an MLJFlux machine
parameters(mach) = vec.(Flux.params(fitted_params(mach)));
To store the traces:
losses = []
training_losses = []
parameter_means = Float32[];
epochs = []
Any[]
To update the traces:
update_loss(loss) = push!(losses, loss)
update_training_loss(losses) = push!(training_losses, losses[end])
update_means(mach) = append!(parameter_means, mean.(parameters(mach)));
update_epochs(epoch) = push!(epochs, epoch)
update_epochs (generic function with 1 method)
The controls to apply:
save_control =
MLJIteration.skip(Save(joinpath(tempdir(), "mnist.jls")), predicate=3)
controls=[
    Step(2),
    Patience(3),
    InvalidValue(),
    TimeLimit(5/60),
    save_control,
    WithLossDo(),
    WithLossDo(update_loss),
    WithTrainingLossesDo(update_training_loss),
    Callback(update_means),
    WithIterationsDo(update_epochs),
];
The "self-iterating" classifier:
iterated_clf = IteratedModel(
clf,
controls=controls,
resampling=Holdout(fraction_train=0.7),
measure=log_loss,
)
ProbabilisticIteratedModel(
model = ImageClassifier(
builder = Main.MyConvBuilder(3, 16, 32, 32),
finaliser = NNlib.softmax,
optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8),
loss = Flux.Losses.crossentropy,
epochs = 30,
batch_size = 50,
lambda = 0.0,
alpha = 0.0,
rng = 123,
optimiser_changes_trigger_retraining = false,
acceleration = CPU1{Nothing}(nothing)),
controls = Any[IterationControl.Step(2), EarlyStopping.Patience(3), EarlyStopping.InvalidValue(), EarlyStopping.TimeLimit(Dates.Millisecond(300000)), IterationControl.Skip{MLJIteration.Save{typeof(Serialization.serialize)}, IterationControl.var"#8#9"{Int64}}(MLJIteration.Save{typeof(Serialization.serialize)}("/tmp/mnist.jls", Serialization.serialize), IterationControl.var"#8#9"{Int64}(3)), IterationControl.WithLossDo{IterationControl.var"#20#22"}(IterationControl.var"#20#22"(), false, nothing), IterationControl.WithLossDo{typeof(Main.update_loss)}(Main.update_loss, false, nothing), IterationControl.WithTrainingLossesDo{typeof(Main.update_training_loss)}(Main.update_training_loss, false, nothing), IterationControl.Callback{typeof(Main.update_means)}(Main.update_means, false, nothing, false), MLJIteration.WithIterationsDo{typeof(Main.update_epochs)}(Main.update_epochs, false, nothing)],
resampling = Holdout(
fraction_train = 0.7,
shuffle = false,
rng = Random._GLOBAL_RNG()),
measure = LogLoss(tol = 2.22045e-16),
weights = nothing,
class_weights = nothing,
operation = nothing,
retrain = false,
check_measure = true,
iteration_parameter = nothing,
cache = true)
Binding the wrapped model to data:
mach = machine(iterated_clf, images, labels);
Training
fit!(mach, rows=train);
[ Info: Training machine(ProbabilisticIteratedModel(model = ImageClassifier(builder = Main.MyConvBuilder(3, 16, 32, 32), …), …), …).
[ Info: No iteration parameter specified. Using `iteration_parameter=:(epochs)`.
[ Info: loss: 2.195050130190149
[ Info: loss: 1.8450074691283658
[ Info: Saving "/tmp/mnist1.jls".
[ Info: loss: 1.1388123685158849
[ Info: loss: 0.702997545486733
[ Info: loss: 0.5778269559910739
[ Info: Saving "/tmp/mnist2.jls".
[ Info: loss: 0.5222495075757826
[ Info: loss: 0.49847208228951995
[ Info: loss: 0.4897800580510804
[ Info: Saving "/tmp/mnist3.jls".
[ Info: loss: 0.4893840844808948
[ Info: loss: 0.49094569068535143
[ Info: loss: 0.49593260647952264
[ Info: Saving "/tmp/mnist4.jls".
[ Info: loss: 0.5062357308150314
[ Info: final loss: 0.5062357308150314
[ Info: final training loss: 0.059303638
[ Info: Stop triggered by EarlyStopping.Patience(3) stopping criterion.
[ Info: Total of 24 iterations.
Comparison of the training and out-of-sample losses:
plot(
epochs,
losses,
xlab = "epoch",
ylab = "cross entropy",
label="out-of-sample",
)
plot!(epochs, training_losses, label="training")
savefig(joinpath(tempdir(), "loss.png"))
"/tmp/loss.png"
Evolution of weights
n_epochs = length(losses)
n_parameters = div(length(parameter_means), n_epochs)
parameter_means2 = reshape(copy(parameter_means), n_parameters, n_epochs)'
plot(
epochs,
parameter_means2,
title="Flux parameter mean weights",
xlab = "epoch",
)
Note. The higher the number in the plot legend, the deeper the layer we are weight-averaging.
savefig(joinpath(tempdir(), "weights.png"))
"/tmp/weights.png"
Retrieving a snapshot for a prediction:
mach2 = machine(joinpath(tempdir(), "mnist3.jls"))
predict_mode(mach2, images[501:503])
3-element CategoricalArrays.CategoricalArray{Int64,1,UInt32}:
7
9
5
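To compare these with the ground truth, and to estimate the snapshot's error on the full holdout set, one could do something like this (a sketch, not part of the original tutorial):

labels[501:503]
misclassification_rate(predict_mode(mach2, images[test]), labels[test])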
Restarting training
Mutating `iterated_clf.controls` or `clf.epochs` (which is otherwise ignored) will allow you to restart training from where it left off.
iterated_clf.controls[2] = Patience(4)
fit!(mach, rows=train)
plot(
epochs,
losses,
xlab = "epoch",
ylab = "cross entropy",
label="out-of-sample",
)
plot!(epochs, training_losses, label="training")
This page was generated using Literate.jl.