Training API Reference
The new version of Flux's training code was written as an independent package, Optimisers.jl. Only the function train! belongs to Flux itself.
The Optimisers package is designed to allow for immutable objects. But at present all Flux models contain parameter arrays (such as Arrays and CuArrays) which can be updated in-place. Because of this:
- The objects returned by Optimisers.update! can be ignored.
- Flux defines its own version of setup which checks this assumption. (Using Optimisers.setup instead will also work; they return the same thing.)
The available optimisation rules are listed on the optimisation rules page. See the Optimisers documentation for details on how the rules work.
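In outline, training with this API is two calls: Flux.setup once, then Flux.train! over the data. A minimal sketch; the model, data, and learning rate below are arbitrary placeholders, not taken from the reference examples that follow:

using Flux

model = Dense(3 => 1)                                       # any Flux model
data  = [(randn(Float32, 3), randn(Float32, 1)) for _ in 1:16]

opt_state = Flux.setup(Adam(0.01), model)                   # encode the rule and its state, once
Flux.train!(model, data, opt_state) do m, x, y              # the loss takes the model first
    Flux.mse(m(x), y)
end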
Flux.Train.setup — Function

opt_state = setup(rule, model)

This is a version of Optimisers.setup, and is the first step before using train!. It differs from Optimisers.setup in that it:
- has one extra check for mutability (since Flux expects to mutate the model in-place, while Optimisers.jl is designed to return an updated model)
- has methods which accept Flux's old optimisers, and convert them. (The old Flux.Optimise.Adam and the new Optimisers.Adam are distinct types.)
Example
julia> model = Dense(2 => 1, leakyrelu; init=ones);
julia> opt_state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state
(weight = Leaf(Momentum(eta=0.1, rho=0.9), [0.0 0.0]), bias = Leaf(Momentum(eta=0.1, rho=0.9), [0.0]), σ = ())
julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps:
julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y
         sum(abs.(m(x) .- y)) * 100
       end
julia> model.bias # was zero, mutated by Flux.train!
1-element Vector{Float64}:
10.19
julia> opt_state # mutated by Flux.train!
(weight = Leaf(Momentum(eta=0.1, rho=0.9), [-2.018 3.027]), bias = Leaf(Momentum(eta=0.1, rho=0.9), [-10.09]), σ = ())

opt_state = setup(rule, model::Duplicated)

Special method for use with Enzyme.jl: equivalent to setup(rule, model.val), it ignores the stored gradient.
Flux.Train.train! — Method

train!(loss, model, data, opt_state)

Uses a loss function and training data to improve the model's parameters according to a particular optimisation rule encoded in opt_state. Iterates through data once, evaluating for each d in data either loss(model, d...) if d isa Tuple, or else loss(model, d) for other d.
If model is an Enzyme.Duplicated and Enzyme.jl is loaded, gradients will be computed with Enzyme, otherwise they will be computed with Zygote.
For example, with these definitions...
data = [(x1, y1), (x2, y2), (x3, y3)]
loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
opt_state = Flux.setup(Adam(), model) # explicit setup of optimiser momenta

...calling Flux.train!(loss3, model, data, opt_state) runs a loop much like this:
for d in data
    ∂L∂m = gradient(loss3, model, d...)[1]
    update!(opt_state, model, ∂L∂m)
end

You can also write this loop yourself, if you need more flexibility; for this reason train! is not highly extensible (a sketch follows the list below). It adds only a few features to the loop above:
- Stop with a DomainError if the loss is infinite or NaN at any point.
- Show a progress bar using @withprogress.
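For instance, a hand-written loop can add per-step logging or early stopping, which train! does not provide. A sketch, assuming the loss3, model, data, and opt_state defined above:

for (i, d) in enumerate(data)
    val, grads = Flux.withgradient(m -> loss3(m, d...), model)  # loss value and gradient together
    isfinite(val) || error("loss is $val at step $i")           # roughly the check train! performs
    Flux.update!(opt_state, model, grads[1])
    @info "step $i" loss = val                                  # any extra code goes here (logging, callbacks, ...)
end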
This method was added in Flux 0.13.9. It has significant changes from the one used by Flux ≤ 0.13:
- It now takes the model itself, not the result of Flux.params. (This is to move away from Zygote's "implicit" parameter handling, with Grads.)
- Instead of loss being a function which accepts only the data, now it must also accept the model itself, as the first argument.
- opt_state should be the result of Flux.setup. Using an optimiser such as Adam() without this step should give you a warning.
- Callback functions are not supported. (But any code can be included in the above for loop.)
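Returning to the Enzyme route mentioned above: wrapping the model in Enzyme.Duplicated makes both setup and train! use Enzyme for gradients. A sketch, assuming Enzyme.jl is installed; Enzyme.make_zero is used here as one way to allocate the shadow (gradient storage):

using Flux, Enzyme

model     = Dense(2 => 1, leakyrelu; init=ones)
dup_model = Enzyme.Duplicated(model, Enzyme.make_zero(model))   # model plus space for its gradient

opt_state = Flux.setup(Momentum(0.1), dup_model)                # sets up dup_model.val, ignores the shadow
data      = [([0.2, -0.3], [0.4])]

Flux.train!(dup_model, data, opt_state) do m, x, y              # gradients computed with Enzyme, not Zygote
    sum(abs.(m(x) .- y)) * 100
end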
Optimisers.update — Function

Optimisers.update(tree, model, gradient) -> (tree, model)

Uses the optimiser and the gradient to change the trainable parameters in the model. Returns the improved model, and the optimiser states needed for the next update. The initial tree of states comes from setup.
See also update!, which will be faster for models of ordinary Arrays or CuArrays.
Example
julia> m = (x = Float32[1,2,3], y = tanh);
julia> t = Optimisers.setup(Descent(0.1), m)
(x = Leaf(Descent(0.1), nothing), y = ())
julia> g = (x = [1,1,1], y = nothing); # fake gradient
julia> Optimisers.update(t, m, g)
((x = Leaf(Descent(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))

Optimisers.update! — Function

Optimisers.update!(tree, model, gradient) -> (tree, model)

Uses the optimiser and the gradient to change the trainable parameters in the model. Returns the improved model, and the optimiser states needed for the next update. The initial tree of states comes from setup.
This is used in exactly the same manner as update, but because it may mutate arrays within the old model (and the old state), it will be faster for models of ordinary Arrays or CuArrays. However, you should not rely on the old model being fully updated but rather use the returned model. (The original state tree is always mutated, as each Leaf is mutable.)
Example
julia> using StaticArrays, Zygote, Optimisers
julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model
julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
(x = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]))
julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
(x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
julia> t2, m2 = Optimisers.update!(t, m, g);
julia> m2 # after update or update!, this is the new model
(x = Float32[0.6666666, 1.5333333], y = Float32[3.6666667, 4.5333333])
julia> m2.x === m.x # update! has re-used this array, for efficiency
true
julia> m # original should be discarded, may be mutated but no guarantee
(x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])
julia> t == t2 # original state tree is guaranteed to be mutated
true

Optimisers.setup — Function

Optimisers.setup(rule, model) -> state_tree

Initialises the given optimiser for every trainable parameter within the model. Returns a tree of the relevant states, which must be passed to update or update!.
Example
julia> m = (x = rand(3), y = (true, false), z = tanh);
julia> Optimisers.setup(Momentum(), m) # same field names as m
(x = Leaf(Momentum(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())

The recursion into structures uses Functors.jl, and any new structs containing parameters need to be marked with Functors.@functor before use. See the Flux docs for more about this.
julia> struct Layer; mat; fun; end
julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);
julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored
(lay = (), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))
julia> destructure(model)
(Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))
julia> using Functors; @functor Layer # annotate this type as containing parameters
julia> Optimisers.setup(Momentum(), model)
(lay = (mat = Leaf(Momentum(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))
julia> destructure(model)
(Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))

train! uses @progress, which should show a progress bar in VSCode automatically. To see one in a terminal, you will need to install TerminalLoggers.jl and follow its setup instructions.
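A sketch of that terminal setup, following TerminalLoggers.jl's own instructions (install the package first, e.g. with Pkg.add("TerminalLoggers")):

using TerminalLoggers: TerminalLogger
using Logging: global_logger

global_logger(TerminalLogger())   # progress messages from @progress / @withprogress now render as a terminal progress bar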
Optimisation Modifiers
The state returned by setup can be modified to temporarily prevent training of some parts of the model, or to change the learning rate or other hyperparameter. The functions for doing so may be accessed as Flux.freeze!, Flux.thaw!, and Flux.adjust!. All mutate the state (or part of it) and return nothing.
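For example, with a two-layer Chain one might freeze the first layer and lower the learning rate mid-training. A sketch; the layers, indices, and values here are arbitrary:

using Flux

model     = Chain(Dense(2 => 3, relu), Dense(3 => 1))
opt_state = Flux.setup(Adam(0.01), model)

Flux.freeze!(opt_state.layers[1])   # parameters of the first Dense layer stop updating
Flux.adjust!(opt_state, 0.001)      # change the learning rate for the whole model
# ... train for a while ...
Flux.thaw!(opt_state)               # all parameters train again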
Optimisers.adjust! — Function

Optimisers.adjust!(tree, η)

Alters the state tree = setup(rule, model) to change the parameters of the optimisation rule, without destroying its stored state. Typically used mid-way through training.
Can be applied to part of a model, by acting only on the corresponding part of the state tree.
To change just the learning rate, provide a number η::Real.
Example
julia> m = (vec = rand(Float32, 2), fun = sin);
julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero
(vec = Leaf(Nesterov(0.001, 0.9), Float32[0.0, 0.0]), fun = ())
julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient
julia> st
(vec = Leaf(Nesterov(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
julia> Optimisers.adjust!(st, 0.123) # change learning rate, stored momentum untouched
julia> st
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())

To change other parameters, adjust! also accepts keyword arguments matching the field names of the optimisation rule's type.
julia> fieldnames(Adam)
(:eta, :beta, :epsilon)
julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
(vec = Leaf(OptimiserChain(ClipGrad(10.0), Adam(0.001, (0.9, 0.999), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1) # delta acts on ClipGrad
(vec = Leaf(OptimiserChain(ClipGrad(11.1), Adam(0.001, (0.777, 0.909), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
julia> Optimisers.adjust(st; beta = "no such field") # silently ignored!
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())

Optimisers.freeze! — Function

Optimisers.freeze!(tree)

Temporarily alters the state tree = setup(rule, model) so that parameters will not be updated. Undone by thaw!.
Can be applied to the state corresponding to only part of a model: for instance, with model::Chain, to freeze model.layers[1] you should call freeze!(tree.layers[1]).
Example
julia> m = (x = ([1.0], 2.0), y = [3.0]);
julia> s = Optimisers.setup(Momentum(), m);
julia> Optimisers.freeze!(s.x)
julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi])); # with fake gradient
julia> m
(x = ([1.0], 2.0), y = [-0.14159265358979312])
julia> s
(x = (Leaf(Momentum(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum(0.01, 0.9), [3.14159]))
julia> Optimisers.thaw!(s)
julia> s.x
(Leaf(Momentum(0.01, 0.9), [0.0]), ())

Optimisers.thaw! — Function

Optimisers.thaw!(tree)

The reverse of freeze!. Applies to all parameters, mutating every Leaf(rule, state, frozen = true) to Leaf(rule, state, frozen = false).