Live Training with MLJFlux

This demonstration is available as a Jupyter notebook or Julia script here.

This script was tested using Julia 1.10.

Basic Imports

using MLJ
using Flux
import Optimisers
using StableRNGs        # for reproducibility across Julia versions

stable_rng() = StableRNGs.StableRNG(123)
stable_rng (generic function with 1 method)
using Plots

Loading and Splitting the Data

iris = load_iris() # a named-tuple of vectors
y, X = unpack(iris, ==(:target), rng=stable_rng())
X = fmap(column -> Float32.(column), X) # Flux prefers Float32 data
(sepal_length = Float32[6.1, 7.3, 6.3, 4.8, 5.9, 7.1, 6.7, 5.4, 6.0, 6.9  …  5.0, 6.4, 5.7, 4.6, 5.5, 4.6, 5.6, 5.7, 6.0, 5.0], sepal_width = Float32[2.9, 2.9, 3.4, 3.4, 3.0, 3.0, 3.0, 3.9, 3.0, 3.1  …  3.3, 2.7, 2.5, 3.2, 2.4, 3.1, 2.8, 3.0, 2.9, 3.5], petal_length = Float32[4.7, 6.3, 5.6, 1.9, 5.1, 5.9, 5.0, 1.7, 4.8, 4.9  …  1.4, 5.3, 5.0, 1.4, 3.7, 1.5, 4.9, 4.2, 4.5, 1.6], petal_width = Float32[1.4, 1.8, 2.4, 0.2, 1.8, 2.1, 1.7, 0.4, 1.8, 1.5  …  0.2, 1.9, 2.0, 0.2, 1.0, 0.2, 2.0, 1.2, 1.5, 0.6])
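
If in doubt about whether the data is in a form MLJ expects, we can inspect the scientific types of the features and the target. A minimal check, using the schema and scitype functions exported by MLJ (exact output depends on package versions):

schema(X)   # each column should now have machine type Float32 and scitype Continuous
scitype(y)  # should be AbstractVector{Multiclass{3}}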

Instantiating the model

Now let's construct our model. This follows a setup similar to the one in the Quick Start.

NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux

clf = NeuralNetworkClassifier(
    builder=MLJFlux.MLP(; hidden=(5,4), σ=Flux.relu),
    optimiser=Optimisers.Adam(0.01),
    batch_size=8,
    epochs=50,
    rng=stable_rng(),
)
NeuralNetworkClassifier(
  builder = MLP(
        hidden = (5, 4), 
        σ = NNlib.relu), 
  finaliser = NNlib.softmax, 
  optimiser = Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), 
  loss = Flux.Losses.crossentropy, 
  epochs = 50, 
  batch_size = 8, 
  lambda = 0.0, 
  alpha = 0.0, 
  rng = StableRNGs.LehmerRNG(state=0x000000000000000000000000000000f7), 
  optimiser_changes_trigger_retraining = false, 
  acceleration = CPU1{Nothing}(nothing), 
  embedding_dims = Dict{Symbol, Real}())
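
Before wrapping the classifier in an iterated model, we could sanity-check it with MLJ's standard evaluate method. The following is just an illustrative sketch; the resampling strategy and number of folds are arbitrary choices:

evaluate(clf, X, y,
         resampling=CV(nfolds=5, rng=stable_rng()),
         measure=log_loss)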

Now let's wrap this in an iterated model. We will use a callback that updates a plot of the validation losses at each iteration.

stop_conditions = [
    Step(1),            # Repeatedly train for one iteration
    NumberLimit(100),   # Don't train for more than 100 iterations
]

validation_losses = []
gr(reuse=true)                  # use the same window for plots
function plot_loss(loss)
    push!(validation_losses, loss)
    display(plot(validation_losses, label="validation loss", xlim=(1, 100)))
    sleep(.01)  # to catch up with the plots while they are being generated
end

callbacks = [ WithLossDo(plot_loss),]

iterated_model = IteratedModel(
    model=clf,
    resampling=Holdout(),
    measures=log_loss,
    iteration_parameter=:(epochs),
    controls=vcat(stop_conditions, callbacks),
    retrain=true,
)
ProbabilisticIteratedModel(
  model = NeuralNetworkClassifier(
        builder = MLP(hidden = (5, 4), …), 
        finaliser = NNlib.softmax, 
        optimiser = Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), 
        loss = Flux.Losses.crossentropy, 
        epochs = 50, 
        batch_size = 8, 
        lambda = 0.0, 
        alpha = 0.0, 
        rng = StableRNGs.LehmerRNG(state=0x000000000000000000000000000000f7), 
        optimiser_changes_trigger_retraining = false, 
        acceleration = CPU1{Nothing}(nothing), 
        embedding_dims = Dict{Symbol, Real}()), 
  controls = Any[IterationControl.Step(1), EarlyStopping.NumberLimit(100), IterationControl.WithLossDo{typeof(Main.plot_loss)}(Main.plot_loss, false, nothing)], 
  resampling = Holdout(
        fraction_train = 0.7, 
        shuffle = false, 
        rng = Random._GLOBAL_RNG()), 
  measure = LogLoss(tol = 2.22045e-16), 
  weights = nothing, 
  class_weights = nothing, 
  operation = nothing, 
  retrain = true, 
  check_measure = true, 
  iteration_parameter = :epochs, 
  cache = true)
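
Many other controls can be combined with the ones above, for example to stop early when the out-of-sample loss stops improving, or to abort if the loss becomes NaN. Here is a sketch using two further controls from MLJ's iteration API (the patience value is just an illustrative choice):

extra_stop_conditions = [
    Patience(n=5),    # stop after 5 consecutive increases in the loss
    InvalidValue(),   # stop if the loss becomes NaN or infinite
]

# these could be supplied as, e.g.,
# controls=vcat(stop_conditions, extra_stop_conditions, callbacks)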

Live Training

Simply fitting the model is all we need:

mach = machine(iterated_model, X, y)
fit!(mach)
validation_losses
100-element Vector{Any}:
 0.8352088626868427
 0.6732681053927665
 0.5739031795541483
 0.5043200472161209
 0.4575995869258531
 0.41717346219584134
 0.38210729479140026
 0.34488092438950246
 0.31056650492452953
 0.274387875506278
 ⋮
 0.025626456909644407
 0.025847989322549007
 0.025857928972110953
 0.02571874249807848
 0.025520618737545187
 0.02535433864076781
 0.025166694645737318
 0.024993059703878436
 0.024808875413973978
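
Because the callback accumulated the losses in validation_losses, the final learning curve can also be re-plotted and saved after training. A minimal sketch using Plots (the filename is just an example):

plot(
    validation_losses,
    xlabel="iteration",
    ylabel="out-of-sample log loss",
    label="",
)
savefig("validation_losses.png")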

Note that the wrapped model sets aside some data on which to make out-of-sample estimates of the loss; this is how the validation_losses are calculated. But because we specified retrain=true, any predictions we make with mach on new input features are based on the model retrained on all of the provided data.

Xnew = (
    sepal_length = Float32[5.8, 5.8, 5.8],
    sepal_width = Float32[4.0, 2.6, 2.7],
    petal_length = Float32[1.2, 4.0, 4.1],
    petal_width = Float32[0.2, 1.2, 1.0],
)

predict_mode(mach, Xnew)
3-element CategoricalArrays.CategoricalArray{String,1,UInt32}:
 "setosa"
 "versicolor"
 "versicolor"

This page was generated using Literate.jl.