FluxTraining.jl comes with a training loop for standard supervised learning problems, but for tasks like self-supervised learning, being able to write custom training logic is essential. The package's training loop API lets you convert a regular Flux.jl training loop with little boilerplate while keeping it compatible with existing callbacks.

We'll explore the API step by step by converting a basic training loop, and then discuss how more complex training loops can be implemented using the same approach. The central piece of a training loop is the logic for a single training step, and in many cases that is all you need to implement. Below is the definition of a basic vanilla Flux.jl training step. It takes a batch of data, calculates the loss and gradients, and finally updates the parameters of the model.
```julia
function step!(model, batch, params, optimizer, lossfn)
    xs, ys = batch
    grads = gradient(params) do
        ŷs = model(xs)
        loss = lossfn(ŷs, ys)
        return loss
    end
    update!(optimizer, params, grads)
end
```

To make a training step work with FluxTraining.jl and its callbacks, we need to

- store data for a step so that callbacks can access it (e.g. `Metrics` uses `ys` and `ŷs` to evaluate metrics for each step); and
- dispatch events so that the callbacks are triggered.
We first need to create a `Phase` and implement a method for `FluxTraining.step!` that dispatches on the phase type. Phases are used to define different training behaviors using the same API and to define callback functionality that only runs during certain phases. For example, `Scheduler` only runs during `AbstractTrainingPhase`s but not during `ValidationPhase`. Let's implement such a phase and method, moving the arguments inside a `Learner` in the process.
```julia
struct MyTrainingPhase <: FluxTraining.AbstractTrainingPhase end

function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch
    grads = gradient(learner.params) do
        ŷs = learner.model(xs)
        loss = learner.lossfn(ŷs, ys)
        return loss
    end
    update!(learner.optimizer, learner.params, grads)
end
```
Now we can already train a model using this implementation, for example with `epoch!(learner, MyTrainingPhase(), dataiter)`. However, no callbacks would be called, since we haven't yet put in any logic that dispatches events or stores the step state. We can do both with the helper function `runstep`, which takes care of running our step logic, dispatches a `StepBegin` and `StepEnd` event before and after, and handles control flow exceptions like `CancelStepException`. Additionally, `runstep` gives us a function `handle`, which we can use to dispatch events inside the step, and `state`, a container for storing step state. Let's use `runstep` and store the variables of interest inside `state`:
```julia
function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs = xs, ys = ys)) do handle, state
        state.grads = gradient(learner.params) do
            state.ŷs = learner.model(state.xs)
            state.loss = learner.lossfn(state.ŷs, state.ys)
            return state.loss
        end
        update!(learner.optimizer, learner.params, state.grads)
    end
end
```
Now callbacks like `Metrics` can access variables like `ys` through `learner.step` (which is set to the last `state`). Finally, we can use `handle` to dispatch additional events:
```julia
using FluxTraining.Events: LossBegin, BackwardBegin, BackwardEnd

function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs = xs, ys = ys)) do handle, state
        state.grads = gradient(learner.params) do
            state.ŷs = learner.model(state.xs)
            handle(LossBegin())
            state.loss = learner.lossfn(state.ŷs, state.ys)
            handle(BackwardBegin())
            return state.loss
        end
        handle(BackwardEnd())
        update!(learner.optimizer, learner.params, state.grads)
    end
end
```
The result is the full implementation of FluxTraining.jl's own `TrainingPhase`! Now we can use `epoch!` to train a `Learner` with full support for all callbacks:
```julia
for i in 1:10
    epoch!(learner, MyTrainingPhase(), dataiter)
end
```
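Running this of course requires a `Learner`. As a rough, hypothetical sketch of what setting one up could look like (the model, loss, data, and the exact `Learner` keyword arguments below are assumptions, not prescriptions — check the `Learner` docstring for your version of FluxTraining.jl):

```julia
using Flux, FluxTraining

# Hypothetical model, loss function, and data iterator; substitute your own.
model = Chain(Dense(10 => 32, relu), Dense(32 => 2))
lossfn = Flux.Losses.logitcrossentropy
dataiter = [(rand(Float32, 10, 16), Flux.onehotbatch(rand(1:2, 16), 1:2)) for _ in 1:100]

# Keyword names here are assumptions; consult the Learner docstring.
learner = Learner(model, lossfn; optimizer = Flux.Adam(), callbacks = [Metrics(accuracy)])

epoch!(learner, MyTrainingPhase(), dataiter)
```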
The implementation of `ValidationPhase` is even simpler; it runs the forward pass and stores variables so that callbacks like `Metrics` can access them:
```julia
struct ValidationPhase <: AbstractValidationPhase end

function step!(learner, phase::ValidationPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs = xs, ys = ys)) do _, state
        state.ŷs = learner.model(state.xs)
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end
```
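Putting the two phases together, a basic fitting loop could alternate training and validation epochs. Here `traindataiter` and `validdataiter` stand in for your own data iterators:

```julia
# Alternate a training epoch and a validation epoch for 10 rounds.
for i in 1:10
    epoch!(learner, MyTrainingPhase(), traindataiter)
    epoch!(learner, ValidationPhase(), validdataiter)
end
```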
We didn't need to implement a custom `epoch!` method for our phase since the default is fine here: it just iterates over every batch and calls `step!`. In fact, let's have a look at how `epoch!` is implemented:
```julia
function epoch!(learner, phase::Phase, dataiter)
    runepoch(learner, phase) do handle
        for batch in dataiter
            step!(learner, phase, batch)
        end
    end
end
```
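If the default behavior doesn't fit your use case, you can implement your own `epoch!` method in the same way. As a hypothetical sketch (the `nsteps` argument is an invented example, not part of FluxTraining.jl's API), a method that only runs a fixed number of steps per epoch could look like this:

```julia
# Hypothetical: an epoch! method for our phase that only runs the first
# `nsteps` batches of the data iterator.
function FluxTraining.epoch!(learner, phase::MyTrainingPhase, dataiter, nsteps::Int)
    FluxTraining.runepoch(learner, phase) do handle
        for (i, batch) in enumerate(dataiter)
            i > nsteps && break
            FluxTraining.step!(learner, phase, batch)
        end
    end
end
```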
`runepoch`, similarly to `runstep`, takes care of epoch start and end events and of control flow. If you want more control over your training loop, you can also use it to write loops that call `step!` directly:
```julia
phase = MyTrainingPhase()
runepoch(learner, phase) do handle
    for batch in dataiter
        step!(learner, phase, batch)
        if learner.step.loss < 0.1
            throw(CancelFittingException("Low loss reached."))
        end
    end
end
```

Here are some additional tips for making it easier to implement complicated training loops:
- You can pass (named) tuples of models to the `Learner` constructor. For example, for generative adversarial training, you can pass in `(generator = ..., critic = ...)` and then refer to them inside the `step!` implementation, e.g. using `learner.model.generator`. The models' parameters will have the same structure, i.e. `learner.params.generator` corresponds to `params(learner.model.generator)`. See the sketch after this list.
- You can store any data you want in `state`.
- When defining a custom phase, instead of subtyping `Phase` you can subtype `AbstractTrainingPhase` or `AbstractValidationPhase` so that some context-specific callbacks work out of the box with your phase type. For example, `Scheduler` sets hyperparameter values only during `AbstractTrainingPhase`s.
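As a hypothetical sketch of the first tip (the placeholder models, the loss, and the `Learner` constructor call below are assumptions for illustration, not FluxTraining.jl's actual adversarial training recipe):

```julia
using Flux, FluxTraining

# Hypothetical placeholder models and loss; substitute your own.
generatormodel = Dense(8 => 16)
criticmodel = Dense(16 => 1)
lossfn = Flux.Losses.mse

struct GeneratorPhase <: FluxTraining.AbstractTrainingPhase end

# Pass a named tuple of models to the Learner (constructor keywords omitted;
# see the Learner docstring for the exact signature).
learner = Learner((generator = generatormodel, critic = criticmodel), lossfn)

# A step that only updates the generator's parameters.
function FluxTraining.step!(learner, phase::GeneratorPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs = xs, ys = ys)) do handle, state
        # learner.params.generator corresponds to params(learner.model.generator).
        state.grads = gradient(learner.params.generator) do
            state.ŷs = learner.model.critic(learner.model.generator(state.xs))
            state.loss = learner.lossfn(state.ŷs, state.ys)
            return state.loss
        end
        update!(learner.optimizer, learner.params.generator, state.grads)
    end
end
```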