FluxTraining.jl comes with a training loop for standard supervised learning problems, but for tasks like self-supervised learning, being able to write custom training logic is essential. The package's training loop API lets you convert a regular Flux.jl training loop with little boilerplate while keeping it compatible with existing callbacks.
We'll explore the API step by step by converting a basic training loop, and then discuss ways in which more complex training loops can be implemented using the same approach. The central piece of a training loop is the logic for a single training step, and in many cases that will be all you need to implement. Below is the definition of a basic vanilla Flux.jl training step. It takes a batch of data, calculates the loss and gradients, and finally updates the parameters of the model.
```julia
function step!(model, batch, params, optimizer, lossfn)
    xs, ys = batch
    grads = gradient(params) do
        ŷs = model(xs)
        loss = lossfn(ŷs, ys)
        return loss
    end
    update!(optimizer, params, grads)
end
```
To make a training step work with FluxTraining.jl and its callbacks, we need to

- store data for a step so that callbacks can access it (e.g. `Metrics` uses `ys` and `ŷs` to evaluate metrics for each step); and
- dispatch events so the callbacks are triggered.
We first need to create a `Phase` and implement a method for `FluxTraining.step!` that dispatches on the phase type. `Phase`s are used to define different training behaviors using the same API and to define callback functionality that only runs during certain phases. For example, `Scheduler` only runs during `AbstractTrainingPhase`s but not during `ValidationPhase`. Let's implement such a phase and method, moving the arguments inside a `Learner` in the process.
```julia
struct MyTrainingPhase <: FluxTraining.AbstractTrainingPhase end

function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch
    grads = gradient(learner.params) do
        ŷs = learner.model(xs)
        loss = learner.lossfn(ŷs, ys)
        return loss
    end
    update!(learner.optimizer, learner.params, grads)
end
```
Now we can already train a model using this implementation, for example with `epoch!(learner, MyTrainingPhase(), dataiter)`. However, no callbacks would be called, since we haven't yet put in any logic that dispatches events or stores the step state. We can do both by using the helper function `runstep`, which takes care of running our step logic, dispatching a `StepBegin` and `StepEnd` event before and after, and handling control flow exceptions like `CancelStepException`. Additionally, `runstep` gives us a function `handle`, which we can use to dispatch events inside the step, and `state`, a container for storing step state. Let's use `runstep` and store the variables of interest inside `state`:
```julia
function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs = xs, ys = ys)) do handle, state
        state.grads = gradient(learner.params) do
            state.ŷs = learner.model(state.xs)
            state.loss = learner.lossfn(state.ŷs, state.ys)
            return state.loss
        end
        update!(learner.optimizer, learner.params, state.grads)
    end
end
```
Now callbacks like `Metrics` can access variables like `ys` through `learner.step` (which is set to the last `state`). Finally, we can use `handle` to dispatch additional events:
```julia
using FluxTraining.Events: LossBegin, BackwardBegin, BackwardEnd

function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs = xs, ys = ys)) do handle, state
        state.grads = gradient(learner.params) do
            state.ŷs = learner.model(state.xs)
            handle(LossBegin())
            state.loss = learner.lossfn(state.ŷs, state.ys)
            handle(BackwardBegin())
            return state.loss
        end
        handle(BackwardEnd())
        update!(learner.optimizer, learner.params, state.grads)
    end
end
```
The result is the full implementation of FluxTraining.jl's own `TrainingPhase`! Now we can use `epoch!` to train a `Learner` with full support for all callbacks:
```julia
for i in 1:10
    epoch!(learner, MyTrainingPhase(), dataiter)
end
```
The implementation of `ValidationPhase` is even simpler; it runs the forward pass and stores variables so that callbacks like `Metrics` can access them.
```julia
struct ValidationPhase <: AbstractValidationPhase end

function step!(learner, phase::ValidationPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (xs = xs, ys = ys)) do _, state
        state.ŷs = learner.model(state.xs)
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end
```
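As a quick usage sketch, the two phases can then be alternated epoch by epoch. Here `traindataiter` and `valdataiter` are assumed names for separate training and validation data iterators:

```julia
# Sketch: alternate one training and one validation epoch per iteration.
# `traindataiter` and `valdataiter` are assumed iterators of (xs, ys) batches.
for i in 1:10
    epoch!(learner, MyTrainingPhase(), traindataiter)
    epoch!(learner, ValidationPhase(), valdataiter)
end
```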
We didn't need to implement a custom `epoch!` method for our phase since the default is fine here: it just iterates over every batch and calls `step!`. In fact, let's have a look at how `epoch!` is implemented:
```julia
function epoch!(learner, phase::Phase, dataiter)
    runepoch(learner, phase) do handle
        for batch in dataiter
            step!(learner, phase, batch)
        end
    end
end
```
Here, `runepoch`, similarly to `runstep`, takes care of epoch start/stop events and control flow. If you want more control over your training loop, you can use it to write training loops that directly use `step!`:
```julia
phase = MyTrainingPhase()
runepoch(learner, phase) do handle
    for batch in dataiter
        step!(learner, phase, batch)
        if learner.step.loss < 0.1
            throw(CancelFittingException("Low loss reached."))
        end
    end
end
```
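The same building blocks also let you customize behavior at the epoch level. As a rough sketch (not part of FluxTraining.jl's API; `maxsteps` is a hypothetical keyword argument), an `epoch!` method for our phase could cap the number of steps per epoch:

```julia
# Sketch: a custom epoch! method that stops after `maxsteps` batches.
# `maxsteps` is a hypothetical keyword argument, not FluxTraining.jl API.
function FluxTraining.epoch!(learner, phase::MyTrainingPhase, dataiter; maxsteps = 100)
    runepoch(learner, phase) do handle
        for (i, batch) in enumerate(dataiter)
            i > maxsteps && break
            FluxTraining.step!(learner, phase, batch)
        end
    end
end
```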
Here are some additional tips for making it easier to implement complicated training loops.
- You can pass (named) tuples of models to the `Learner` constructor. For example, for generative adversarial training, you can pass in `(generator = ..., critic = ...)` and then refer to them inside the `step!` implementation, e.g. as `learner.model.generator`. The models' parameters will have the same structure, i.e. `learner.params.generator` corresponds to `params(learner.model.generator)`. See the sketch after this list.
- You can store any data you want in `state`.
- When defining a custom phase, instead of subtyping `Phase` you can subtype `AbstractTrainingPhase` or `AbstractValidationPhase` so that some context-specific callbacks will work out of the box with your phase type. For example, `Scheduler` sets hyperparameter values only during `AbstractTrainingPhase`s.
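To illustrate the first tip, here is a rough sketch of what a GAN-style training step could look like; it is not an implementation shipped with FluxTraining.jl. It assumes `learner.model` is the named tuple `(generator = ..., critic = ...)` from above, that a single `learner.optimizer` is shared by both models, and that `criticloss` and `generatorloss` are hypothetical loss functions; the latent dimension of 100 is an arbitrary choice.

```julia
# Sketch only: one adversarial training step that first updates the critic,
# then the generator. `criticloss` and `generatorloss` are hypothetical.
struct GANTrainingPhase <: FluxTraining.AbstractTrainingPhase end

function FluxTraining.step!(learner, phase::GANTrainingPhase, batch)
    xs = batch  # a batch of real samples, batched along the last dimension
    runstep(learner, phase, (xs = xs,)) do _, state
        # Hypothetical 100-dimensional latent vectors, one per real sample.
        noise = randn(Float32, 100, size(xs)[end])

        # Critic update on real and generated samples.
        criticgrads = gradient(learner.params.critic) do
            state.criticloss = criticloss(
                learner.model.critic, xs, learner.model.generator(noise))
        end
        update!(learner.optimizer, learner.params.critic, criticgrads)

        # Generator update: make the critic rate generated samples as real.
        generatorgrads = gradient(learner.params.generator) do
            state.generatorloss = generatorloss(
                learner.model.critic, learner.model.generator(noise))
        end
        update!(learner.optimizer, learner.params.generator, generatorgrads)
    end
end
```

In a real setup you would likely give each model its own optimizer and schedule the critic and generator updates differently; the point here is only how the named tuple structure of `learner.model` and `learner.params` carries through `step!`.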