Training API Reference
The new version of Flux's training code was written as an independent package, Optimisers.jl. Only the function train! belongs to Flux itself.
The Optimisers package is designed to allow for immutable objects. But at present all Flux models contain parameter arrays (such as Arrays and CuArrays) which can be updated in-place. Because of this:
- The objects returned by Optimisers.update! can be ignored.
- Flux defines its own version of setup which checks this assumption. (Using Optimisers.setup instead will also work; they return the same thing.)
Flux.Train.setup — Function
opt_state = setup(rule, model)
This is a version of Optimisers.setup, and is the first step before using train!. It differs from Optimisers.setup in that it:
- has one extra check for mutability (since Flux expects to mutate the model in-place, while Optimisers.jl is designed to return an updated model)
- has methods which accept Flux's old optimisers, and convert them. (The old Flux.Optimise.Adam and new Optimisers.Adam are distinct types.)
This function was added in Flux 0.13.9. It was not used by the old "implicit" interface, which relied on the Flux.Optimise module and Flux.params.
Example
julia> model = Dense(2=>1, leakyrelu; init=ones);
julia> opt_state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [0.0]), σ = ())
julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps:
julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y
         sum(abs.(m(x) .- y)) * 100
       end
julia> model.bias # was zero, mutated by Flux.train!
1-element Vector{Float64}:
10.19
julia> opt_state # mutated by Flux.train!
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [-10.09]), σ = ())
Flux.Optimise.train! — Method
train!(loss, model, data, opt_state)
Uses a loss function and training data to improve the model's parameters according to a particular optimisation rule encoded in opt_state. Iterates through data once, evaluating for each d in data either loss(model, d...) if d isa Tuple, or else loss(model, d) for other d.
For example, with these definitions...
data = [(x1, y1), (x2, y2), (x3, y3)]
loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
opt_state = Flux.setup(Adam(), model) # explicit setup of optimiser momenta
...calling Flux.train!(loss3, model, data, opt_state) runs a loop much like this:
for d in data
    ∂L∂m = gradient(loss3, model, d...)[1]
    update!(opt_state, model, ∂L∂m)
end
You can also write this loop yourself, if you need more flexibility. For this reason train! is not highly extensible. It adds only a few features to the loop above:
- Stop with a DomainError if the loss is infinite or NaN at any point.
- Show a progress bar using @withprogress.
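As a minimal sketch of writing that loop by hand, the following reproduces the gradient step and the finite-loss check described above. The model, the random data, and the error message are illustrative assumptions; the progress bar is omitted, and this is not claimed to be how train! is implemented internally.

```julia
using Flux
using LinearAlgebra: norm

# Hypothetical model and data, chosen only for illustration.
model = Dense(2 => 1)
data = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:3]

loss3(m, x, y) = norm(m(x) .- y)        # the model is the first argument
opt_state = Flux.setup(Adam(), model)   # explicit setup of optimiser state

for (i, d) in enumerate(data)
    # withgradient returns both the loss value and the gradient,
    # so the loss can be checked without a second forward pass.
    l, grads = Flux.withgradient(m -> loss3(m, d...), model)
    isfinite(l) || throw(DomainError(l, "loss is not finite on batch $i"))
    Flux.update!(opt_state, model, grads[1])
end
```

Because the loop body is ordinary code, logging, early stopping, or evaluation on held-out data can be added at any point inside it.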
This method was added in Flux 0.13.9. It has significant changes from the one used by Flux ≤ 0.13:
- It now takes the model itself, not the result of Flux.params. (This is to move away from Zygote's "implicit" parameter handling, with Grads.)
- Instead of loss being a function which accepts only the data, now it must also accept the model itself, as the first argument.
- opt_state should be the result of Flux.setup. Using an optimiser such as Adam() without this step should give you a warning.
- Callback functions are not supported. (But any code can be included in the above for loop.)
Optimisers.update! — Function
Optimisers.update!(tree, model, gradient) -> (tree, model)
Uses the optimiser and the gradient to change the trainable parameters in the model. Returns the improved model, and the optimiser states needed for the next update. The initial tree of states comes from setup.
This is used in exactly the same manner as update, but because it may mutate arrays within the old model (and the old state), it will be faster for models of ordinary Arrays or CuArrays. However, you should not rely on the old model being fully updated but rather use the returned model. (The original state tree is always mutated, as each Leaf is mutable.)
Example
julia> using StaticArrays, Zygote, Optimisers
julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model
julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
(x = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]))
julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
(x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
julia> t2, m2 = Optimisers.update!(t, m, g);
julia> m2 # after update or update!, this is the new model
(x = Float32[0.6666666, 1.5333333], y = Float32[3.6666667, 4.5333333])
julia> m2.x === m.x # update! has re-used this array, for efficiency
true
julia> m # original should be discarded, may be mutated but no guarantee
(x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])
julia> t == t2 # original state tree is guaranteed to be mutated
true
train! uses @progress which should show a progress bar in VSCode automatically. To see one in a terminal, you will need to install TerminalLoggers.jl and follow its setup instructions.
Optimisation Modifiers
The state returned by setup can be modified to temporarily prevent training of some parts of the model, or to change the learning rate or other hyperparameter. The functions for doing so may be accessed as Flux.freeze!, Flux.thaw!, and Flux.adjust!. All mutate the state (or part of it) and return nothing.
Optimisers.adjust! — Function
Optimisers.adjust!(tree, η)
Alters the state tree = setup(rule, model) to change the parameters of the optimisation rule, without destroying its stored state. Typically used mid-way through training.
Can be applied to part of a model, by acting only on the corresponding part of the state tree.
To change just the learning rate, provide a number η::Real.
Example
julia> m = (vec = rand(Float32, 2), fun = sin);
julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero
(vec = Leaf(Nesterov(0.001, 0.9), Float32[0.0, 0.0]), fun = ())
julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient
julia> st
(vec = Leaf(Nesterov(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
julia> Optimisers.adjust!(st, 0.123) # change learning rate, stored momentum untouched
julia> st
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
To change other parameters, adjust! also accepts keyword arguments matching the field names of the optimisation rule's type.
julia> fieldnames(Adam)
(:eta, :beta, :epsilon)
julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
(vec = Leaf(OptimiserChain(ClipGrad(10.0), Adam(0.001, (0.9, 0.999), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1) # delta acts on ClipGrad
(vec = Leaf(OptimiserChain(ClipGrad(11.1), Adam(0.001, (0.777, 0.909), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
julia> Optimisers.adjust(st; beta = "no such field") # silently ignored!
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
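One common use of adjust! is a learning-rate schedule applied between epochs. The sketch below assumes a hypothetical schedule that halves the rate each epoch; the model, data, and loss are made up for illustration, and only Flux.setup, Flux.adjust!, and Flux.train! come from the text above.

```julia
using Flux

# Hypothetical model, data, and loss, for illustration only.
model = Dense(2 => 1)
data = [(rand(Float32, 2, 4), rand(Float32, 1, 4)) for _ in 1:5]
loss(m, x, y) = Flux.mse(m(x), y)

opt_state = Flux.setup(Momentum(0.1), model)

for epoch in 1:3
    eta = 0.1 / 2^(epoch - 1)       # assumed schedule: halve the rate each epoch
    Flux.adjust!(opt_state, eta)    # changes the rate, keeps the stored momentum
    Flux.train!(loss, model, data, opt_state)
end
```

Since adjust! leaves the stored momentum untouched, training continues smoothly across the change instead of restarting from a freshly initialised state.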
Optimisers.freeze! — Function
Optimisers.freeze!(tree)
Temporarily alters the state tree = setup(rule, model) so that parameters will not be updated. Un-done by thaw!.
Can be applied to the state corresponding to only part of a model, for instance with model::Chain, to freeze model.layers[1] you should call freeze!(tree.layers[1]).
Example
julia> m = (x = ([1.0], 2.0), y = [3.0]);
julia> s = Optimisers.setup(Momentum(), m);
julia> Optimisers.freeze!(s.x)
julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi])); # with fake gradient
julia> m
(x = ([1.0], 2.0), y = [-0.14159265358979312])
julia> s
(x = (Leaf(Momentum(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum(0.01, 0.9), [3.14159]))
julia> Optimisers.thaw!(s)
julia> s.x
(Leaf(Momentum(0.01, 0.9), [0.0]), ())
Optimisers.thaw! — Function
Optimisers.thaw!(tree)
The reverse of freeze!. Applies to all parameters, mutating every Leaf(rule, state, frozen = true) to Leaf(rule, state, frozen = false).
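The same pattern can be tried on a Flux Chain. The following is a sketch under assumed names (a two-layer model and random data invented for illustration): it freezes the first layer, takes one update step, and then thaws the whole state tree.

```julia
using Flux

# Hypothetical two-layer model and a single random batch.
model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)

opt_state = Flux.setup(Adam(0.01), model)
Flux.freeze!(opt_state.layers[1])            # first layer will not be trained

w1_before = copy(model[1].weight)            # copies, to compare after the step
b2_before = copy(model[2].bias)

grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
Flux.update!(opt_state, model, grads[1])     # only the second layer changes

Flux.thaw!(opt_state)                        # all layers are trainable again
```

Note that the gradient is still computed for the frozen layer; freezing only tells the optimiser to skip its update.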
Implicit style (Flux ≤ 0.14)
Flux used to handle gradients, training, and optimisation rules quite differently. The new style described above is called "explicit" by Zygote, and the old style "implicit". Flux 0.13 and 0.14 are the transitional versions which support both; Flux 0.15 will remove the old.
The blue-green boxes in the training section describe the changes needed to upgrade old code.
The available rules are listed on the optimisation rules page.
The new implementation of rules such as Adam in Optimisers.jl is quite different from the old one in Flux.Optimise. In Flux 0.14, Flux.Adam() still returns the old one, with supertype Flux.Optimise.AbstractOptimiser, but setup will silently translate it to its new counterpart.
For full details on the interface for implicit-style optimisers, see the Flux 0.13.6 manual. See the Optimisers documentation for details on how the new rules work.
Much earlier versions of Flux exported params, thus allowing unqualified params(model) after using Flux. This conflicted with too many other packages, and was removed in Flux 0.13. If you get an error UndefVarError: params not defined, this probably means that you are following code for Flux 0.12 or earlier on a more recent version.
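To make the difference between the two styles concrete, here is a small sketch of the same training call written both ways. The implicit version is kept in comments, since it relies on the interface scheduled for removal; the model, data, and loss are hypothetical.

```julia
using Flux

# Hypothetical model and data, for illustration only.
model = Dense(2 => 1)
data = [(rand(Float32, 2, 4), rand(Float32, 1, 4)) for _ in 1:3]

# Implicit style (Flux ≤ 0.14), shown as comments only:
#   loss(x, y) = Flux.mse(model(x), y)   # closes over the global model
#   Flux.train!(loss, Flux.params(model), data, Adam())

# Explicit style: the loss takes the model as its first argument,
# and the optimiser state is created once with Flux.setup.
loss(m, x, y) = Flux.mse(m(x), y)
opt_state = Flux.setup(Adam(), model)
Flux.train!(loss, model, data, opt_state)
```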
Flux.params — Function
params(model)
params(layers...)
Given a model or specific layers from a model, create a Params object pointing to its trainable parameters.
This can be used with the gradient function, see the training section of the manual, or as input to the Flux.train! function.
The behaviour of params on custom types can be customized using Functors.@functor or Flux.trainable.
Examples
julia> using Flux: params
julia> params(Chain(Dense(ones(2,3)), softmax)) # unpacks Flux models
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])
julia> bn = BatchNorm(2, relu)
BatchNorm(2, relu) # 4 parameters, plus 4 non-trainable
julia> params(bn) # only the trainable parameters
Params([Float32[0.0, 0.0], Float32[1.0, 1.0]])
julia> params([1, 2, 3], [4]) # one or more arrays of numbers
Params([[1, 2, 3], [4]])
julia> params([[1, 2, 3], [4]]) # unpacks array of arrays
Params([[1, 2, 3], [4]])
julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin)) # ignores scalars, unpacks NamedTuples
Params([[2 2], [3, 3, 3]])
Optimisers.update! — Method
update!(opt, p, g)
update!(opt, ps::Params, gs)
Perform an update step of the parameters ps (or the single parameter p) according to optimiser opt::AbstractOptimiser and the gradients gs (the gradient g).
As a result, the parameters are mutated and the optimiser's internal state may change. The gradient could be mutated as well.
This method for implicit Params (and AbstractOptimiser) will be removed from Flux 0.15. The explicit method update!(opt, model, grad) from Optimisers.jl will remain.
Flux.Optimise.train! — Method
train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
Uses a loss function and training data to improve the model's parameters according to a particular optimisation rule opt.
This method with implicit Params will be removed from Flux 0.15. It should be replaced with the explicit method train!(loss, model, data, opt).
For each d in data, first the gradient of the loss is computed like this:
gradient(() -> loss(d...), pars) # if d isa Tuple
gradient(() -> loss(d), pars) # otherwise
Here pars is produced by calling Flux.params on your model. (Or just on the layers you want to train, like train!(loss, params(model[1:end-2]), data, opt).) This is the "implicit" style of parameter handling.
This gradient is then used by optimiser opt to update the parameters:
update!(opt, pars, grads)
The optimiser should be from the Flux.Optimise module (see Optimisers). Different optimisers can be combined using Flux.Optimise.Optimiser.
This training loop iterates through data once. It will stop with a DomainError if the loss is NaN or infinite.
You can use train! inside a for loop to do this several times, or use for instance IterTools.ncycle to make a longer data iterator.
Callbacks
Callbacks are given with the keyword argument cb. For example, this will print "training" every 10 seconds (using Flux.throttle):
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
Multiple callbacks can be passed to cb as an array.
Callbacks
Implicit train! takes an additional argument, cb, that's used for callbacks so that you can observe the training process. For example:
train!(objective, ps, data, opt, cb = () -> println("training"))
Callbacks are called for every batch of training data. You can slow this down using Flux.throttle(f, timeout) which prevents f from being called more than once every timeout seconds.
A more typical callback might look like this:
test_x, test_y = # ... create single batch of test data ...
evalcb() = @show(loss(test_x, test_y))
throttled_cb = throttle(evalcb, 5)
for epoch in 1:20
    @info "Epoch $epoch"
    Flux.train!(objective, ps, data, opt, cb = throttled_cb)
end
See the page about callback helpers for more.