Live Training with MLJFlux
This demonstration is available as a Jupyter notebook or Julia script here.
This script was tested using Julia 1.10.
Basic Imports
using MLJ
using Flux
import Optimisers
using StableRNGs # for reproducibility across Julia versions
stable_rng() = StableRNGs.StableRNG(123)

stable_rng (generic function with 1 method)

using Plots

Loading and Splitting the Data
iris = load_iris() # a named-tuple of vectors
y, X = unpack(iris, ==(:target), rng=stable_rng())
X = fmap(column -> Float32.(column), X) # Flux prefers Float32 data

(sepal_length = Float32[6.1, 7.3, 6.3, 4.8, 5.9, 7.1, 6.7, 5.4, 6.0, 6.9 … 5.0, 6.4, 5.7, 4.6, 5.5, 4.6, 5.6, 5.7, 6.0, 5.0], sepal_width = Float32[2.9, 2.9, 3.4, 3.4, 3.0, 3.0, 3.0, 3.9, 3.0, 3.1 … 3.3, 2.7, 2.5, 3.2, 2.4, 3.1, 2.8, 3.0, 2.9, 3.5], petal_length = Float32[4.7, 6.3, 5.6, 1.9, 5.1, 5.9, 5.0, 1.7, 4.8, 4.9 … 1.4, 5.3, 5.0, 1.4, 3.7, 1.5, 4.9, 4.2, 4.5, 1.6], petal_width = Float32[1.4, 1.8, 2.4, 0.2, 1.8, 2.1, 1.7, 0.4, 1.8, 1.5 … 0.2, 1.9, 2.0, 0.2, 1.0, 0.2, 2.0, 1.2, 1.5, 0.6])

Instantiating the model
Now let's construct our model. This follows a setup similar to the one in the Quick Start.
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
clf = NeuralNetworkClassifier(
builder=MLJFlux.MLP(; hidden=(5,4), σ=Flux.relu),
optimiser=Optimisers.Adam(0.01),
batch_size=8,
epochs=50,
rng=stable_rng(),
)

NeuralNetworkClassifier(
builder = MLP(
hidden = (5, 4),
σ = NNlib.relu),
finaliser = NNlib.softmax,
optimiser = Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8),
loss = Flux.Losses.crossentropy,
epochs = 50,
batch_size = 8,
lambda = 0.0,
alpha = 0.0,
rng = StableRNGs.LehmerRNG(state=0x000000000000000000000000000000f7),
optimiser_changes_trigger_retraining = false,
acceleration = CPU1{Nothing}(nothing),
embedding_dims = Dict{Symbol, Real}())

Now let's wrap this in an iterated model. We will use a callback that updates a plot of the validation losses at each iteration.
stop_conditions = [
Step(1), # Repeatedly train for one iteration
NumberLimit(100), # Don't train for more than 100 iterations
]
validation_losses = []
gr(reuse=true) # use the same window for plots
function plot_loss(loss)
push!(validation_losses, loss)
display(plot(validation_losses, label="validation loss", xlim=(1, 100)))
sleep(.01) # to catch up with the plots while they are being generated
end
callbacks = [ WithLossDo(plot_loss),]
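Other controls could be mixed into the same list if desired. The following is only a sketch (it is not used in this demonstration) of two further stopping criteria available in MLJ: stopping after several consecutive increases in the out-of-sample loss, and stopping as soon as the loss becomes invalid.

# A sketch only (not used below): additional stopping controls one could add
extra_stop_conditions = [
    Patience(n=5),   # stop after 5 consecutive increases in the out-of-sample loss
    InvalidValue(),  # stop if the loss becomes NaN, Inf or -Inf
]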
iterated_model = IteratedModel(
model=clf,
resampling=Holdout(),
measures=log_loss,
iteration_parameter=:(epochs),
controls=vcat(stop_conditions, callbacks),
retrain=true,
)

ProbabilisticIteratedModel(
model = NeuralNetworkClassifier(
builder = MLP(hidden = (5, 4), …),
finaliser = NNlib.softmax,
optimiser = Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8),
loss = Flux.Losses.crossentropy,
epochs = 50,
batch_size = 8,
lambda = 0.0,
alpha = 0.0,
rng = StableRNGs.LehmerRNG(state=0x000000000000000000000000000000f7),
optimiser_changes_trigger_retraining = false,
acceleration = CPU1{Nothing}(nothing),
embedding_dims = Dict{Symbol, Real}()),
controls = Any[IterationControl.Step(1), EarlyStopping.NumberLimit(100), IterationControl.WithLossDo{typeof(Main.plot_loss)}(Main.plot_loss, false, nothing)],
resampling = Holdout(
fraction_train = 0.7,
shuffle = false,
rng = Random._GLOBAL_RNG()),
measure = LogLoss(tol = 2.22045e-16),
weights = nothing,
class_weights = nothing,
operation = nothing,
retrain = true,
check_measure = true,
iteration_parameter = :epochs,
cache = true)

Live Training
Simply fitting the model is all we need:
mach = machine(iterated_model, X, y)
fit!(mach)
validation_losses

100-element Vector{Any}:
0.8352088626868427
0.6732681053927665
0.5739031795541483
0.5043200472161209
0.4575995869258531
0.41717346219584134
0.38210729479140026
0.34488092438950246
0.31056650492452953
0.274387875506278
⋮
0.025626456909644407
0.025847989322549007
0.025857928972110953
0.02571874249807848
0.025520618737545187
0.02535433864076781
0.025166694645737318
0.024993059703878436
0.024808875413973978

Note that the wrapped model sets aside some data on which to make out-of-sample estimates of the loss, which is how validation_losses are calculated. But if we use mach to make predictions on new input features, these are based on retraining the model on all provided data.
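As an aside, predict (rather than predict_mode, used below) returns a full class distribution for each row, from which individual class probabilities can be extracted with pdf. A sketch, using a single made-up observation:

# A sketch only: probabilistic predictions for one made-up observation
Xone = (
    sepal_length = Float32[5.0],
    sepal_width = Float32[3.4],
    petal_length = Float32[1.5],
    petal_width = Float32[0.2],
)
yhat = predict(mach, Xone)  # a vector of UnivariateFinite distributions
pdf.(yhat, "setosa")        # probability assigned to "setosa" for each row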
Xnew = (
sepal_length = Float32[5.8, 5.8, 5.8],
sepal_width = Float32[4.0, 2.6, 2.7],
petal_length = Float32[1.2, 4.0, 4.1],
petal_width = Float32[0.2, 1.2, 1.0],
)
predict_mode(mach, Xnew)

3-element CategoricalArrays.CategoricalArray{String,1,UInt32}:
"setosa"
"versicolor"
"versicolor"This page was generated using Literate.jl.