train!
Function defined in module Flux.Optimise.
train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
Uses a loss function and training data to improve the model's parameters according to a particular optimisation rule opt.
This method with implicit Params will be removed from Flux 0.14. It should be replaced with the explicit method train!(loss, model, data, opt).
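For comparison, here is a minimal sketch of the explicit style that replaces this method. It assumes a Flux version providing Flux.setup (0.13.9 or later); the toy model and data below are made up for illustration.

using Flux

model = Dense(2 => 1)
data = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:16]    # (x, y) pairs
opt_state = Flux.setup(Adam(0.01), model)                      # optimiser state, not a Params
Flux.train!((m, x, y) -> Flux.Losses.mse(m(x), y), model, data, opt_state)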
For each d in data, first the gradient of the loss is computed like this:
gradient(() -> loss(d...), pars) # if d isa Tuple
gradient(() -> loss(d), pars) # otherwise
Here pars is produced by calling Flux.params on your model. (Or just on the layers you want to train, like train!(loss, params(model[1:end-2]), data, opt).) This is the "implicit" style of parameter handling.
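For a hypothetical toy model (the layer sizes, data and loss below are made up for illustration), the pieces above look like this:

using Flux

model = Chain(Dense(2 => 8, relu), Dense(8 => 1))
data = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:16]    # vector of (x, y) tuples
loss(x, y) = Flux.Losses.mse(model(x), y)                      # closes over `model`
pars = Flux.params(model)                                      # implicit Params collection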
This gradient is then used by optimiser opt to update the parameters:
update!(opt, pars, grads)
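Putting this together, a call like train!(loss, pars, data, opt) is roughly equivalent to the sketch below (callbacks and the check that the loss is finite are omitted; Descent is just one choice of rule):

using Flux.Optimise: Descent, update!

opt = Descent(0.1)                                   # plain gradient-descent rule
for d in data                                        # same pass train! makes over the data
    grads = Flux.gradient(() -> loss(d...), pars)    # implicit-style gradient, as above
    update!(opt, pars, grads)
end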
The optimiser should be from the Flux.Optimise module (see Optimisers). Different optimisers can be combined using Flux.Optimise.Optimiser.
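For instance, a sketch of chaining gradient clipping with ADAM, both taken from Flux.Optimise:

using Flux.Optimise: Optimiser, ClipValue, ADAM

opt = Optimiser(ClipValue(1f-3), ADAM(1f-3))    # clip each gradient entry, then apply ADAM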
This training loop iterates through data once. It will stop with a DomainError if the loss is NaN or infinite.
You can use @epochs to do this several times, or use for instance IterTools.ncycle to make a longer data iterator.
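For instance, either of the following runs ten passes over the same data (a sketch reusing loss, pars, data and opt from above; @epochs is available as Flux.@epochs in versions that still provide it):

using IterTools: ncycle

train!(loss, pars, ncycle(data, 10), opt)        # one call, data repeated 10 times
Flux.@epochs 10 train!(loss, pars, data, opt)    # or repeat the whole call with @epochs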
Callbacks are given with the keyword argument cb. For example, this will print "training" every 10 seconds (using Flux.throttle):
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
The callback can call Flux.stop to interrupt the training loop.
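For instance, a sketch of a callback that halts training once the loss on a held-out batch is small enough (x_val and y_val are hypothetical):

halt() = loss(x_val, y_val) < 0.01 && Flux.stop()
train!(loss, pars, data, opt; cb = halt)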
Multiple callbacks can be passed to cb as an array.
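For example, combining the halt callback above with the throttled printing shown earlier:

train!(loss, pars, data, opt; cb = [halt, Flux.throttle(() -> println("training"), 10)])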