`train!`
Function defined in module `Flux.Optimise`.

```julia
train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
```
Uses a `loss` function and training `data` to improve the model's parameters according to a particular optimisation rule `opt`.
This method with implicit `Params` will be removed from Flux 0.14. It should be replaced with the explicit method `train!(loss, model, data, opt)`.
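As a minimal sketch of this implicit-parameter style (the model, data, loss and optimiser below are illustrative choices, not part of the API):

```julia
using Flux

# Illustrative setup: a tiny model, random batches, and a plain gradient-descent rule.
model = Dense(2 => 1)
pars  = Flux.params(model)                      # implicit parameter container
data  = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:10]  # 10 (x, y) batches
loss(x, y) = Flux.Losses.mse(model(x), y)
opt   = Descent(0.1)

Flux.train!(loss, pars, data, opt)              # one pass over `data`
```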
For each `d in data`, first the gradient of the `loss` is computed like this:

```julia
gradient(() -> loss(d...), pars)  # if d isa Tuple
gradient(() -> loss(d), pars)     # otherwise
```
Here `pars` is produced by calling `Flux.params` on your model. (Or just on the layers you want to train, like `train!(loss, params(model[1:end-2]), data, opt)`.) This is the "implicit" style of parameter handling.
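For instance, a sketch of collecting parameters for a whole `Chain` versus only its last layer (the model here is made up for illustration):

```julia
using Flux

# All trainable arrays of both layers, versus only those of the final layer.
model = Chain(Dense(4 => 3, relu), Dense(3 => 1))
pars  = Flux.params(model)        # weights and biases of both Dense layers
head  = Flux.params(model[end])   # only the last layer's parameters
```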
This gradient is then used by optimiser `opt` to update the parameters:

```julia
update!(opt, pars, grads)
```
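Putting the two steps together, one pass of `train!` behaves roughly like this hand-written loop (a simplified sketch that omits callbacks and the NaN check):

```julia
using Flux
using Flux.Optimise: update!

for d in data
    # Compute gradients with respect to the implicit parameters...
    grads = gradient(() -> loss(d...), pars)   # assumes each `d` is a Tuple
    # ...and apply the optimiser's update rule in place.
    update!(opt, pars, grads)
end
```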
The optimiser should be from the `Flux.Optimise` module (see Optimisers). Different optimisers can be combined using `Flux.Optimise.Optimiser`.
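For example, gradient clipping can be chained with Adam (a sketch; the hyperparameter values are arbitrary, and `Adam` is spelled `ADAM` in older Flux releases):

```julia
using Flux
using Flux.Optimise: Optimiser, ClipValue, Adam

# Clip each gradient value, then apply the Adam update rule.
opt = Optimiser(ClipValue(1e-3), Adam(0.001))
```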
This training loop iterates through `data` once. It will stop with a `DomainError` if the loss is `NaN` or infinite.
You can use `@epochs` to do this several times, or use for instance `IterTools.ncycle` to make a longer `data` iterator.
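For instance (a sketch; `IterTools` is a separate package, and `loss`, `pars`, `data`, `opt` are as above):

```julia
using Flux
using IterTools: ncycle

# Ten passes over the data by lengthening the iterator...
Flux.train!(loss, pars, ncycle(data, 10), opt)

# ...or by repeating the call with the @epochs macro.
Flux.@epochs 10 Flux.train!(loss, pars, data, opt)
```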
Callbacks are given with the keyword argument `cb`. For example, this will print "training" every 10 seconds (using `Flux.throttle`):

```julia
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
```
The callback can call `Flux.stop` to interrupt the training loop.
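For example, a sketch of early stopping (here `accuracy` is a hypothetical metric you would define yourself):

```julia
# Stop training once the (user-defined) metric is good enough.
stop_cb() = accuracy() > 0.9 && Flux.stop()

Flux.train!(loss, pars, data, opt, cb = stop_cb)
```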
Multiple callbacks can be passed to `cb` as an array.
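For example, combining the throttled logger with the stopping callback sketched above:

```julia
Flux.train!(loss, pars, data, opt,
            cb = [Flux.throttle(() -> println("training"), 10), stop_cb])
```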