Optimisation Rules
Optimisers.Descent (Type)
Descent(η = 1f-1)
Descent(; [eta])
Classic gradient descent optimiser with learning rate η. For each parameter p and its gradient dp, this runs p -= η*dp.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
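The following is a minimal sketch of this rule in plain Julia, with no Optimisers.jl required. The helper `descent_step!` is hypothetical; it applies p -= η*dp in place. It mirrors the `update` example further down, where Descent(0.1) turns [1, 2, 3] into [0.9, 1.9, 2.9].

```julia
# Hypothetical helper illustrating the Descent rule; not part of Optimisers.jl.
function descent_step!(p::AbstractArray, dp::AbstractArray; η = 0.1)
    p .-= η .* dp   # subtract the scaled gradient, in place
    return p
end

p = [1.0, 2.0, 3.0]
descent_step!(p, [1.0, 1.0, 1.0]; η = 0.1)   # p is now approximately [0.9, 1.9, 2.9]
```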
Optimisers.Momentum (Type)
Momentum(η = 0.01, ρ = 0.9)
Momentum(; [eta, rho])
Gradient descent optimizer with learning rate η and momentum ρ.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Momentum (ρ == rho): Controls the acceleration of gradient descent in the prominent direction, in effect dampening oscillations.
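Here is a plain-Julia sketch of one momentum step. It assumes the velocity form v ← ρv + η·dx followed by x ← x − v. That assumption agrees with the `update!` example further down, where Momentum(1/30, 0.9) moves x = [1, 2] to about [0.667, 1.533] for the gradient [10, 14]. The helper name is hypothetical.

```julia
# Hypothetical helper, assuming the update v ← ρ v + η dx, then x ← x − v.
function momentum_step!(x, v, dx; η = 0.01, ρ = 0.9)
    @. v = ρ * v + η * dx   # update the stored velocity in place
    @. x = x - v            # move the parameters by the velocity
    return x, v
end

x = [1.0, 2.0]
v = zeros(2)                # velocity starts at zero, as in setup
momentum_step!(x, v, [10.0, 14.0]; η = 1/30, ρ = 0.9)
```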
Optimisers.Nesterov (Type)
Nesterov(η = 0.001, ρ = 0.9)
Nesterov(; [eta, rho])
Gradient descent optimizer with learning rate η and Nesterov momentum ρ.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Nesterov momentum (ρ == rho): Controls the acceleration of gradient descent in the prominent direction, in effect dampening oscillations.
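One common way to implement Nesterov momentum avoids evaluating the gradient at a look-ahead point. Instead, it rewrites the step in terms of the stored velocity. The sketch below assumes that reformulation: step = (1 + ρ)·η·dx − ρ²·v, then v ← ρv − η·dx. With a zero initial velocity and the gradient [16, 88], the stored velocity becomes [-0.016, -0.088]. This matches the `adjust!` example further down. The helper name is hypothetical.

```julia
# Hypothetical helper, assuming the reformulated Nesterov update.
function nesterov_step!(x, v, dx; η = 0.001, ρ = 0.9)
    step = @. (1 + ρ) * η * dx - ρ^2 * v   # uses the old velocity
    @. v = ρ * v - η * dx                  # stored velocity has the opposite sign to the step
    @. x = x - step
    return x, v
end

x = [0.0, 0.0]
v = zeros(2)
nesterov_step!(x, v, [16.0, 88.0])   # v is now approximately [-0.016, -0.088]
```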
Optimisers.Rprop (Type)
Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
Rprop(; [eta, ell, gamma])
Optimizer using the Rprop algorithm, a full-batch learning algorithm that depends only on the sign of the gradient.
Parameters
- Learning rate (η == eta): Initial step size for each parameter, which is then adapted within the bounds set by Γ.
- Scaling factors (ℓ::Tuple == ell): Multiplicative decrease and increase factors for the step size, in that order.
- Step sizes (Γ::Tuple == gamma): Minimal and maximal allowed step sizes.
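To illustrate "depends only on the sign of the gradient", here is a hedged plain-Julia sketch of one common Rprop variant (iRprop-). Each element keeps its own step size. The step grows by ℓ[2] while the gradient keeps its sign and shrinks by ℓ[1] when the sign flips, always staying within Γ. Whether Optimisers.jl uses exactly this variant is an assumption here, and the helper name is hypothetical.

```julia
# Hypothetical sketch of an iRprop- step; x, gprev and stepsize are updated in place.
function rprop_step!(x, gprev, stepsize, dx; ℓ = (0.5, 1.2), Γ = (1e-6, 50.0))
    for i in eachindex(x, gprev, stepsize, dx)
        s = gprev[i] * dx[i]
        if s > 0        # same sign as last time: accelerate
            stepsize[i] = min(stepsize[i] * ℓ[2], Γ[2])
        elseif s < 0    # sign flipped: we overshot, so slow down
            stepsize[i] = max(stepsize[i] * ℓ[1], Γ[1])
        end
        g = s < 0 ? zero(dx[i]) : dx[i]   # after a flip, skip this element's move
        x[i] -= stepsize[i] * sign(g)     # only the sign of the gradient is used
        gprev[i] = g
    end
    return x
end

x = [1.0, 1.0]
gprev = zeros(2)
stepsize = fill(1e-3, 2)                      # initial step size η
rprop_step!(x, gprev, stepsize, [3.0, -2.0])  # no history yet: step sizes unchanged
rprop_step!(x, gprev, stepsize, [5.0, 4.0])   # element 1 keeps its sign, element 2 flips
```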
Optimisers.RMSProp (Type)
RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred = false)
RMSProp(; [eta, rho, epsilon, centred])
Optimizer using the RMSProp algorithm. Often a good choice for recurrent networks. Parameters other than learning rate generally don't need tuning.
Centred RMSProp is a variant which normalises gradients by an estimate of their variance, instead of their second moment.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Momentum (ρ == rho): Decay rate of the running average of squared gradients (and, in the centred variant, of the gradients themselves), in effect dampening oscillations.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
- Keyword centred (or centered): Indicates whether to use the centred variant of the algorithm.
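Here is a plain-Julia sketch of one RMSProp step, assuming the usual formulation. It keeps a running average quad of squared gradients with decay ρ. The centred variant also keeps a running average lin of the gradients themselves, so that quad − lin² estimates the variance. The helper name and argument layout are hypothetical.

```julia
# Hypothetical sketch of one (optionally centred) RMSProp step; updates x, quad, lin in place.
function rmsprop_step!(x, quad, lin, dx; η = 0.001, ρ = 0.9, ϵ = 1e-8, centred = false)
    @. quad = ρ * quad + (1 - ρ) * dx^2      # running second moment
    if centred
        @. lin = ρ * lin + (1 - ρ) * dx      # running mean, only tracked when centred
    end
    # Non-centred: lin stays zero, so this divides by the root of the second moment.
    @. x = x - η * dx / (sqrt(quad - lin^2) + ϵ)
    return x
end

x1 = [1.0]; rmsprop_step!(x1, [0.0], [0.0], [2.0])                  # divides by sqrt(0.4)
x2 = [1.0]; rmsprop_step!(x2, [0.0], [0.0], [2.0]; centred = true)  # divides by sqrt(0.36)
```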
Optimisers.Adam (Type)
Adam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
Adam(; [eta, beta, epsilon])
Adam optimiser.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Decay of moments (β::Tuple == beta): Exponential decay rates for the first (β1) and second (β2) moment estimates.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
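Here is a plain-Julia sketch of one Adam step. It assumes the standard formulation from Kingma and Ba, with bias correction of both moment estimates. The caller supplies the step counter t, starting at 1. The helper name and argument layout are hypothetical, not Optimisers.jl's internals.

```julia
# Hypothetical sketch of one Adam step; x, m and v are updated in place.
function adam_step!(x, m, v, t, dx; η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
    @. m = β[1] * m + (1 - β[1]) * dx      # first moment (mean of gradients)
    @. v = β[2] * v + (1 - β[2]) * dx^2    # second moment (mean of squared gradients)
    m̂ = m ./ (1 - β[1]^t)                  # bias corrections: both averages start at zero
    v̂ = v ./ (1 - β[2]^t)
    @. x = x - η * m̂ / (sqrt(v̂) + ϵ)
    return x
end

x = [1.0, 1.0]
m, v = zeros(2), zeros(2)
adam_step!(x, m, v, 1, [2.0, -3.0])   # the first step has size close to η per element
```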
Optimisers.RAdam (Type)
RAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
RAdam(; [eta, beta, epsilon])
Rectified Adam optimizer.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Decay of moments (β::Tuple == beta): Exponential decay rates for the first (β1) and second (β2) moment estimates.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
Optimisers.AdaMax (Type)
AdaMax(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
AdaMax(; [eta, beta, epsilon])
AdaMax is a variant of Adam based on the ∞-norm.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Decay of moments (β::Tuple == beta): Exponential decay rates for the first (β1) and second (β2) moment estimates.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
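AdaMax replaces Adam's second-moment average with an exponentially weighted infinity norm, u ← max(β2·u, |dx|), which needs no bias correction. The plain-Julia sketch below assumes that formulation, as in the original Adam paper. The helper name is hypothetical.

```julia
# Hypothetical sketch of one AdaMax step at step count t (starting at 1).
function adamax_step!(x, m, u, t, dx; η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
    @. m = β[1] * m + (1 - β[1]) * dx       # first moment, as in Adam
    @. u = max(β[2] * u, abs(dx))           # ∞-norm: decayed running maximum of |dx|
    @. x = x - η / (1 - β[1]^t) * m / (u + ϵ)
    return x
end

x = [1.0, 1.0]
m, u = zeros(2), zeros(2)
adamax_step!(x, m, u, 1, [2.0, -3.0])
adamax_step!(x, m, u, 2, [1.0, 1.0])   # u keeps the decayed earlier maxima, about [1.998, 2.997]
```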
Optimisers.OAdam (Type)
OAdam(η = 0.001, β = (0.5, 0.9), ϵ = 1e-8)
OAdam(; [eta, beta, epsilon])
OAdam (Optimistic Adam) is a variant of Adam adding an "optimistic" term suitable for adversarial training.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Decay of moments (β::Tuple == beta): Exponential decay rates for the first (β1) and second (β2) moment estimates.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
Optimisers.AdaGrad (Type)
AdaGrad(η = 0.1, ϵ = 1e-8)
AdaGrad(; [eta, epsilon])
AdaGrad optimizer. It has parameter-specific learning rates, based on how frequently each parameter is updated. Parameters don't need tuning.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
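A plain-Julia sketch shows the parameter-specific learning rate at work. It assumes the classic accumulator form acc ← acc + dx² and x ← x − η·dx / (sqrt(acc) + ϵ). Each parameter's effective step shrinks as its own gradient history grows. Starting acc at zero and the helper name are assumptions of this sketch.

```julia
# Hypothetical sketch of one AdaGrad step; x and acc are updated in place.
function adagrad_step!(x, acc, dx; η = 0.1, ϵ = 1e-8)
    @. acc = acc + dx^2                      # per-parameter sum of squared gradients
    @. x = x - η * dx / (sqrt(acc) + ϵ)      # larger history, smaller effective step
    return x
end

x = [0.0, 0.0]
acc = zeros(2)
for _ in 1:3
    adagrad_step!(x, acc, [1.0, 0.0])   # only the first parameter ever receives gradient
end
# Steps taken by x[1]: about 0.1, 0.1/sqrt(2), 0.1/sqrt(3).
```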
Optimisers.AdaDelta (Type)
AdaDelta(ρ = 0.9, ϵ = 1e-8)
AdaDelta(; [rho, epsilon])
AdaDelta is a version of AdaGrad adapting its learning rate based on a window of past gradient updates. Parameters don't need tuning.
Parameters
- Rho (ρ == rho): Factor by which the gradient is decayed at each time step.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
Optimisers.AMSGrad (Type)
AMSGrad(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
AMSGrad(; [eta, beta, epsilon])
The AMSGrad version of the Adam optimiser. Parameters don't need tuning.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Decay of moments (β::Tuple == beta): Exponential decay rates for the first (β1) and second (β2) moment estimates.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
Optimisers.NAdam (Type)
NAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
NAdam(; [eta, beta, epsilon])
NAdam is a Nesterov variant of Adam. Parameters don't need tuning.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Decay of moments (β::Tuple == beta): Exponential decay rates for the first (β1) and second (β2) moment estimates.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
Optimisers.AdamW (Type)
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8; couple = true)
AdamW(; [eta, beta, lambda, epsilon, couple])
AdamW is a variant of Adam fixing (as in repairing) its weight decay regularization. Implemented as an OptimiserChain of Adam and WeightDecay.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Decay of moments (β::Tuple == beta): Exponential decay rates for the first (β1) and second (β2) moment estimates.
- Weight decay (λ == lambda): Controls the strength of $L_2$ regularisation.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
- Keyword couple: If true, the weight decay is coupled with the learning rate, as in PyTorch's AdamW. This corresponds to an update of the form x = x - η * (dx + λ * x), where dx is the update from Adam with learning rate 1. If false, the weight decay is decoupled from the learning rate, in the spirit of the original paper. This corresponds to an update of the form x = x - η * dx - λ * x. Default is true.
With version 0.4 the default update rule for AdamW has changed to match the PyTorch implementation. The previous rule, which is closer to the original paper, can be obtained by setting AdamW(..., couple=false). See this issue for more details.
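The two update forms quoted above can be compared directly. In this sketch, adam_dx is a made-up stand-in for the Adam update computed with learning rate 1, and both helper functions are hypothetical. With couple = true the decay is scaled by η, so the decoupled form with λ replaced by η·λ gives the same result.

```julia
# Hypothetical one-step versions of the two AdamW rules described above.
coupled(x, adam_dx; η, λ)   = x .- η .* (adam_dx .+ λ .* x)   # couple = true
decoupled(x, adam_dx; η, λ) = x .- η .* adam_dx .- λ .* x     # couple = false

x = [2.0, -1.0]
adam_dx = [1.0, 0.5]            # made-up Adam update, for illustration only
a = coupled(x, adam_dx; η = 0.1, λ = 0.01)
b = decoupled(x, adam_dx; η = 0.1, λ = 0.01)
c = decoupled(x, adam_dx; η = 0.1, λ = 0.1 * 0.01)   # rescaled λ reproduces `a`
```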
Optimisers.AdaBelief (Type)
AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
AdaBelief(; [eta, beta, epsilon])
The AdaBelief optimiser is a variant of the well-known Adam optimiser.
Parameters
- Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
- Decay of moments (β::Tuple == beta): Exponential decay rates for the first (β1) and second (β2) moment estimates.
- Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default).
Optimisers.Lion (Type)
Lion(η = 0.001, β = (0.9, 0.999))
Lion(; [eta, beta])
Lion optimiser.
Parameters
- Learning rate (η == eta): Magnitude by which gradients update the weights.
- Decay of momentum (β::Tuple == beta): Exponential decay rates. β1 weights the stored momentum when forming the update direction, and β2 is used when updating the stored momentum.
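Here is a plain-Julia sketch of one Lion step, assuming the update from the paper that introduced Lion, without its optional weight decay term. The direction is the sign of an interpolation, weighted by β1, between the stored momentum and the current gradient. The momentum is then updated with β2. Because only the sign is used, every element moves by exactly η or not at all. The helper name is hypothetical.

```julia
# Hypothetical sketch of one Lion step; x and m are updated in place.
function lion_step!(x, m, dx; η = 0.001, β = (0.9, 0.999))
    @. x = x - η * sign(β[1] * m + (1 - β[1]) * dx)   # direction: sign of the interpolation
    @. m = β[2] * m + (1 - β[2]) * dx                 # momentum update uses β2
    return x
end

x = [1.0, 1.0]
m = zeros(2)
lion_step!(x, m, [0.5, -2.0])   # each element moves by exactly η: x is about [0.999, 1.001]
```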
In addition to the main course, you may wish to order some of these condiments:
Optimisers.AccumGrad (Type)
AccumGrad(n::Int)
A rule constructed as OptimiserChain(AccumGrad(n), Rule()) will accumulate gradients for n steps, before applying Rule to the mean of these n gradients.
This is useful for training with effective batch sizes too large for the available memory. Instead of computing the gradient for batch size b at once, compute it for size b/n and accumulate n such gradients.
Example
julia> m = (x=[1f0], y=[2f0]);
julia> r = OptimiserChain(AccumGrad(2), WeightDecay(0.01), Descent(0.1));
julia> s = Optimisers.setup(r, m);
julia> Optimisers.update!(s, m, (x=[33], y=[0]));
julia> m # model not yet changed
(x = Float32[1.0], y = Float32[2.0])
julia> Optimisers.update!(s, m, (x=[0], y=[444]));
julia> m # n=2 gradients applied at once
(x = Float32[-0.651], y = Float32[-20.202002])
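The numbers in this example can be reproduced by hand. The plain-Julia sketch below does not use Optimisers.jl. It averages the two gradients, then applies the WeightDecay(0.01) and Descent(0.1) steps once, which is what the chain does on the second call.

```julia
# Re-computing the AccumGrad example above in plain Julia (n = 2).
x, y = 1.0, 2.0
grads = [(x = 33.0, y = 0.0), (x = 0.0, y = 444.0)]
n = length(grads)
gx = sum(g.x for g in grads) / n          # mean gradient for x: 16.5
gy = sum(g.y for g in grads) / n          # mean gradient for y: 222.0
λ, η = 0.01, 0.1
x_new = x - η * (gx + λ * x)              # WeightDecay adds λ*x, Descent scales by η
y_new = y - η * (gy + λ * y)              # about -0.651 and -20.202, as shown above
```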
Optimisers.ClipGrad (Type)
ClipGrad(δ = 10)
ClipGrad(; [delta])
Restricts every gradient component to obey -δ ≤ dx[i] ≤ δ.
Typically composed with other rules using OptimiserChain.
See also ClipNorm.
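In plain Julia this rule is an elementwise clamp. The helper `clipgrad` is a hypothetical name, and the first input reuses the one from the OptimiserChain example further down.

```julia
# Hypothetical helper: restrict each gradient component to the interval [-δ, δ].
clipgrad(dx, δ = 10) = clamp.(dx, -δ, δ)

clipgrad([0.3, 1, 7], 1.0)   # returns [0.3, 1.0, 1.0]
```

Because each component is clamped separately, the direction of the gradient can change. ClipNorm instead rescales the whole array.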
Optimisers.ClipNorm (Type)
ClipNorm(ω = 10, p = 2; throw = true)
Scales any gradient array for which norm(dx, p) > ω to stay at this threshold (unless p == 0).
Throws an error if the norm is infinite or NaN, which you can turn off with throw = false.
Typically composed with other rules using OptimiserChain.
See also ClipGrad.
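Here is a plain-Julia sketch of this rule using LinearAlgebra's `norm`, assuming p ≠ 0. The helper `clipnorm` is a hypothetical name. Rescaling the whole array keeps the gradient's direction.

```julia
using LinearAlgebra: norm

# Hypothetical helper: rescale dx so that norm(dx, p) does not exceed ω (p ≠ 0 assumed).
function clipnorm(dx, ω = 10, p = 2; throw = true)
    nrm = norm(dx, p)
    if throw && !isfinite(nrm)
        error("gradient norm is $nrm")     # default: refuse Inf or NaN gradients
    end
    return nrm > ω ? dx .* (ω / nrm) : dx  # unchanged when already within the threshold
end

clipnorm([3.0, 4.0], 1.0)   # norm 5 is scaled down to 1: about [0.6, 0.8]
```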
Optimisers.SignDecay (Type)
SignDecay(λ = 1e-3)
SignDecay(; [lambda])
Implements $L_1$ regularisation, also known as LASSO regression, when composed with other rules as the first transformation in an OptimiserChain.
It does this by adding λ .* sign(x) to the gradient. This is equivalent to adding λ * sum(abs, x) == λ * norm(x, 1) to the loss.
See also WeightDecay for $L_2$ regularisation. They can be used together: OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam()) is equivalent to adding 0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2 to the loss function.
Parameters
- Penalty (λ ≥ 0): Controls the strength of the regularisation.
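The combined equivalence stated above can be checked numerically in plain Julia. SignDecay(0.012) and WeightDecay(0.034) add the extra gradient 0.012 .* sign.(x) .+ 0.034 .* x. Central finite differences of the penalty 0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2 should match it, away from zero entries (where |x| has no derivative). The helper names are hypothetical.

```julia
using LinearAlgebra: norm

# Penalty added to the loss, and the gradient the two rules add instead.
penalty(x) = 0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2
added_grad(x) = 0.012 .* sign.(x) .+ 0.034 .* x

# Central finite differences; x must avoid zeros, where |x| is not differentiable.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x); e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

x = [1.5, -0.7, 2.0]
fd_grad(penalty, x)   # should agree with added_grad(x)
```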
Optimisers.WeightDecay (Type)
WeightDecay(λ = 5e-4)
WeightDecay(; [lambda])
Implements $L_2$ regularisation, also known as ridge regression, when composed with other rules as the first transformation in an OptimiserChain.
It does this by adding λ .* x to the gradient. This is equivalent to adding λ/2 * sum(abs2, x) == λ/2 * norm(x)^2 to the loss.
See also SignDecay for $L_1$ regularisation.
Parameters
- Penalty (λ ≥ 0): Controls the strength of the regularisation.
Optimisers.OptimiserChain (Type)
OptimiserChain(opts...)
Compose a sequence of optimisers so that each opt in opts updates the gradient, in the order specified.
With an empty sequence, OptimiserChain() is the identity, so update! will subtract the full gradient from the parameters. This is equivalent to Descent(1).
Example
julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
julia> m = (zeros(3),);
julia> s = Optimisers.setup(o, m)
(Leaf(OptimiserChain(ClipGrad(1.0), Descent(0.1)), (nothing, nothing)),)
julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
([-0.03, -0.1, -0.1],)
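The chaining behaviour can be mimicked in plain Julia by folding a list of gradient transformations in order. In this sketch, `clip` and `scale` are hypothetical stand-ins for ClipGrad(1.0) and Descent(0.1). Swapping their order changes the result. An empty chain returns the gradient unchanged, matching the identity behaviour of OptimiserChain() described above.

```julia
# Hypothetical stand-ins for the rules in the example above.
clip(dx; δ = 1.0) = clamp.(dx, -δ, δ)          # like ClipGrad(1.0)
scale(dx; η = 0.1) = η .* dx                   # like Descent(0.1)
chain(dx, rules...) = foldl((g, r) -> r(g), rules; init = dx)   # apply rules left to right

x = zeros(3)
x_new = x .- chain([0.3, 1, 7], clip, scale)      # clips first: about [-0.03, -0.1, -0.1]
x_swapped = x .- chain([0.3, 1, 7], scale, clip)  # scales first: about [-0.03, -0.1, -0.7]
```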
Model Interface
Optimisers.setup (Function)
Optimisers.setup(rule, model) -> state_tree
Initialises the given optimiser for every trainable parameter within the model. Returns a tree of the relevant states, which must be passed to update or update!.
Example
julia> m = (x = rand(3), y = (true, false), z = tanh);
julia> Optimisers.setup(Momentum(), m) # same field names as m
(x = Leaf(Momentum(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
The recursion into structures uses Functors.jl, and any new structs containing parameters need to be marked with Functors.@functor before use. See the Flux docs for more about this.
julia> struct Layer; mat; fun; end
julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);
julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored
(lay = (), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))
julia> destructure(model)
(Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))
julia> using Functors; @functor Layer # annotate this type as containing parameters
julia> Optimisers.setup(Momentum(), model)
(lay = (mat = Leaf(Momentum(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))
julia> destructure(model)
(Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
Optimisers.update (Function)
Optimisers.update(tree, model, gradient) -> (tree, model)
Uses the optimiser and the gradient to change the trainable parameters in the model. Returns the optimiser states needed for the next update, and the improved model. The initial tree of states comes from setup.
See also update!, which will be faster for models of ordinary Arrays or CuArrays.
Example
julia> m = (x = Float32[1,2,3], y = tanh);
julia> t = Optimisers.setup(Descent(0.1), m)
(x = Leaf(Descent(0.1), nothing), y = ())
julia> g = (x = [1,1,1], y = nothing); # fake gradient
julia> Optimisers.update(t, m, g)
((x = Leaf(Descent(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
Optimisers.update! (Function)
Optimisers.update!(tree, model, gradient) -> (tree, model)
Uses the optimiser and the gradient to change the trainable parameters in the model. Returns the optimiser states needed for the next update, and the improved model. The initial tree of states comes from setup.
This is used in exactly the same manner as update, but because it may mutate arrays within the old model (and the old state), it will be faster for models of ordinary Arrays or CuArrays. However, you should not rely on the old model being fully updated; use the returned model instead. (The original state tree is always mutated, as each Leaf is mutable.)
Example
julia> using StaticArrays, Zygote, Optimisers
julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model
julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
(x = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]))
julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
(x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
julia> t2, m2 = Optimisers.update!(t, m, g);
julia> m2 # after update or update!, this is the new model
(x = Float32[0.6666666, 1.5333333], y = Float32[3.6666667, 4.5333333])
julia> m2.x === m.x # update! has re-used this array, for efficiency
true
julia> m # original should be discarded, may be mutated but no guarantee
(x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])
julia> t == t2 # original state tree is guaranteed to be mutated
true
Optimisers.adjust! (Function)
Optimisers.adjust!(tree, η)
Alters the state tree = setup(rule, model) to change the parameters of the optimisation rule, without destroying its stored state. Typically used mid-way through training.
Can be applied to part of a model, by acting only on the corresponding part of the state tree.
To change just the learning rate, provide a number η::Real.
Example
julia> m = (vec = rand(Float32, 2), fun = sin);
julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero
(vec = Leaf(Nesterov(0.001, 0.9), Float32[0.0, 0.0]), fun = ())
julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient
julia> st
(vec = Leaf(Nesterov(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
julia> Optimisers.adjust!(st, 0.123) # change learning rate, stored momentum untouched
julia> st
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
To change other parameters, adjust! also accepts keyword arguments matching the field names of the optimisation rule's type.
julia> fieldnames(Adam)
(:eta, :beta, :epsilon)
julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
(vec = Leaf(OptimiserChain(ClipGrad(10.0), Adam(0.001, (0.9, 0.999), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1) # delta acts on ClipGrad
(vec = Leaf(OptimiserChain(ClipGrad(11.1), Adam(0.001, (0.777, 0.909), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
julia> Optimisers.adjust(st; beta = "no such field") # silently ignored!
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
Optimisers.adjust (Method)
adjust(tree, η) -> tree
Like adjust!, but returns a new tree instead of mutating the old one.
Optimisers.freeze! (Function)
Optimisers.freeze!(tree)
Temporarily alters the state tree = setup(rule, model) so that parameters will not be updated. This is undone by thaw!.
Can be applied to the state corresponding to only part of a model. For instance, with model::Chain, to freeze model.layers[1] you should call freeze!(tree.layers[1]).
Example
julia> m = (x = ([1.0], 2.0), y = [3.0]);
julia> s = Optimisers.setup(Momentum(), m);
julia> Optimisers.freeze!(s.x)
julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi])); # with fake gradient
julia> m
(x = ([1.0], 2.0), y = [-0.14159265358979312])
julia> s
(x = (Leaf(Momentum(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum(0.01, 0.9), [3.14159]))
julia> Optimisers.thaw!(s)
julia> s.x
(Leaf(Momentum(0.01, 0.9), [0.0]), ())
Optimisers.thaw! (Function)
Optimisers.thaw!(tree)
The reverse of freeze!. Applies to all parameters, mutating every Leaf(rule, state, frozen = true) to Leaf(rule, state, frozen = false).
Calling Functors.@functor on your model's layer types by default causes these functions to recurse into all children, and ultimately optimise all isnumeric leaf nodes. To further restrict this by ignoring some fields of a layer type, define trainable:
Optimisers.trainable (Function)
trainable(x::Layer) -> NamedTuple
This may be overloaded to make optimisers ignore some fields of every Layer, which would otherwise contain trainable parameters.
This is very rarely required. Fields of struct Layer which contain functions, or integers like sizes, are always ignored anyway. Overloading trainable is only necessary when some arrays of numbers are to be optimised, and some arrays of numbers are not.
The default is Functors.children(x), usually a NamedTuple of all fields, and trainable(x) must contain a subset of these.
Optimisers.isnumeric (Function)
isnumeric(x) -> Bool
Returns true on any parameter to be adjusted by Optimisers.jl, namely arrays of non-integer numbers. Returns false on all other types.
Requires also that Functors.isleaf(x) == true, to focus on e.g. the parent of a transposed matrix, not the wrapper.
Optimisers.maywrite (Function)
maywrite(x) -> Bool
Should return true if we are completely sure that update! can write new values into x. Otherwise it returns false, indicating a non-mutating path. For now, this is simply x isa DenseArray, allowing Array, CuArray, etc.
Such restrictions are also obeyed by this function for flattening a model:
Optimisers.destructure (Function)
destructure(model) -> vector, reconstructor
Copies all trainable, isnumeric parameters in the model to a vector, and also returns a function which reverses this transformation. Differentiable.
Example
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))
julia> re([3, 5, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))
If model contains various number types, they are promoted to make vector, and are usually restored by Restructure. Such restoration follows the rules of ChainRulesCore.ProjectTo, and thus will restore floating point precision, but will permit more exotic numbers like ForwardDiff.Dual.
If model contains only GPU arrays, then vector will also live on the GPU. At present, a mixture of GPU and ordinary CPU arrays is undefined behaviour.
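For intuition, here is a much-simplified plain-Julia sketch of the flatten-and-rebuild idea, restricted to a flat NamedTuple of arrays. The name `flatten_nt` is hypothetical, and the sketch handles none of the nesting, trainable or type-promotion rules above. Entries are copied in column-major order, as in the destructure output in the setup example.

```julia
# Hypothetical, much-simplified analogue of destructure for a flat NamedTuple of arrays.
function flatten_nt(nt::NamedTuple)
    arrays = collect(values(nt))
    v = reduce(vcat, vec.(arrays))          # column-major copy of every entry
    function rebuild(p)
        offset = 0
        parts = map(arrays) do a            # slice p back into arrays of the original shapes
            part = reshape(p[offset+1:offset+length(a)], size(a))
            offset += length(a)
            part
        end
        return NamedTuple{keys(nt)}(Tuple(parts))
    end
    return v, rebuild
end

v, re = flatten_nt((w = [1.0 2.0; 3.0 4.0], b = [5.0, 6.0]))   # v == [1.0, 3.0, 2.0, 4.0, 5.0, 6.0]
m2 = re(collect(1.0:6.0))
```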
Optimisers.Restructure (Type)
Restructure(Model, ..., length)
This is what destructure returns, and re(p) will re-build the model with new parameters from vector p. If the model is callable, then re(x, p) == re(p)(x).
Example
julia> using Flux, Optimisers
julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid))
([1, 3, 2, 4, 0, 0], Restructure(Dense, ..., 6))
julia> m = re(-4:1)
Dense(2, 2, σ) # 6 parameters
julia> m([0.2, 0.3]) ≈ re([0.2, 0.3], -4:1)
true
Optimisers.trainables (Function)
trainables(x; path = false)
Return an iterable over all the trainable parameters in x, that is, all the numerical arrays (see isnumeric) which are reachable through trainable.
Parameters appearing multiple times in the model (tied weights) will be present only once in the output.
If path = false, the output is a list of numerical arrays.
If path = true, the output is a list of (KeyPath, AbstractArray) pairs, where KeyPath is a type representing the path to the array in the original structure.
See also destructure for a similar operation that returns a single flat vector instead.
Examples
julia> struct MyLayer
           w
           b
       end
julia> Functors.@functor MyLayer
julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,) # only w is trainable in this example
julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]);
julia> trainables(x)
1-element Vector{AbstractArray}:
[1.0, 2.0, 3.0]
julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);
julia> trainables(x) # collects nested parameters
2-element Vector{AbstractArray}:
[1.0, 2.0]
[3.0]
julia> x = (a = [1.0,2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0,7.0]));
julia> for (kp, y) in trainables(x, path = true)
           println(kp, " => ", y)
       end
KeyPath(:a,) => [1.0, 2.0]
KeyPath(:b, 1, "c") => [3.0, 4.0]
KeyPath(:b, 2) => [6.0, 7.0]
julia> getkeypath(x, KeyPath(:b, 1, "c"))
2-element Vector{Float64}:
3.0
4.0
Rule Definition
Optimisers.apply! (Function)
Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)
This defines the action of any optimisation rule. It should return the modified gradient which will be subtracted from the parameters, and the updated state (if any) for use at the next iteration, as a tuple (state, gradient).
For efficiency it is free to mutate the old state, but only what is returned will be used. Ideally this should check maywrite(x), which the built-in rules do via the @.. macro.
The initial state is init(rule::RuleType, parameters).
Example
julia> Optimisers.init(Descent(0.1), Float32[1,2,3]) === nothing
true
julia> Optimisers.apply!(Descent(0.1), nothing, Float32[1,2,3], [4,5,6])
(nothing, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(*, ([4, 5, 6], 0.1f0)))
Optimisers.init (Function)
Optimisers.init(rule::RuleType, parameters) -> state
Sets up the initial state for a given optimisation rule, and an array of parameters. This and apply! are the two functions which any new optimisation rule must define.
Examples
julia> Optimisers.init(Descent(), Float32[1,2,3]) # is `nothing`
julia> Optimisers.init(Momentum(), [1.0, 2.0])
2-element Vector{Float64}:
0.0
0.0
Optimisers.@.. (Macro)
@.. x = y + z
Sometimes-in-place broadcasting macro, for use in apply! rules. If maywrite(x), then it is just @. x = rhs, but if not, it becomes x = @. rhs.
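The "sometimes in-place" idea can be illustrated without the macro itself. The sketch below is hypothetical and is not how @.. is implemented. It uses x isa DenseArray, the current maywrite test described above, to choose between writing into x and returning a new result.

```julia
# Hypothetical stand-in for maywrite: the DenseArray test described above.
maywrite_sketch(x) = x isa DenseArray

# Decay x by ρ, writing in place only when that is known to be safe.
function decay!(x, ρ)
    if maywrite_sketch(x)
        @. x = ρ * x        # in place: reuses x's memory
        return x
    else
        return @. ρ * x     # out of place: x itself is left unchanged
    end
end

a = [1.0, 2.0]              # an Array is a DenseArray, so it is overwritten
r = 1.0:2.0                 # a range is not a DenseArray, so a new result is returned
b = decay!(a, 0.5)
c = decay!(r, 0.5)
```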
Optimisers.@lazy (Macro)
x = @lazy y + z
Lazy broadcasting macro, for use in apply! rules. It broadcasts like @. but does not materialise, returning a Broadcasted object for later use. Beware that mutation of arguments will affect the result, and that if it is used in two places, work will be done twice.
Optimisers.adjust (Method)
Optimisers.adjust(rule::RuleType, η::Real) -> rule
If a new optimisation rule has a learning rate which is not stored in field rule.eta, then you should add a method to adjust. (But it is simpler to just use the standard name.)
Optimisers.@def (Macro)
@def struct Rule; eta = 0.1; beta = (0.7, 0.8); end
Helper macro for defining rules with default values. The types of the literal values are used in the struct, like this:
struct Rule
    eta::Float64
    beta::Tuple{Float64, Float64}
    Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
    Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
end
Any field called eta is assumed to be a learning rate, and cannot be negative.
KeyPath
A KeyPath
is a sequence of keys that can be used to access a value within a nested structure. It is defined in Functors.jl and re-exported by Optimisers.jl here for convenience.
Functors.KeyPath (Type)
KeyPath(keys...)
A type for representing a path of keys to a value in a nested structure. Can be constructed with a sequence of keys, or by concatenating other KeyPaths. Keys can be of type Symbol, String, or Int.
For custom types, access through symbol keys is assumed to be done with getproperty. For consistency, the method Base.propertynames is used to get the viable property names.
For string and integer keys instead, the access is done with getindex.
See also getkeypath, haskeypath.
Examples
julia> kp = KeyPath(:b, 3)
KeyPath(:b, 3)
julia> KeyPath(:a, kp, :c, 4) # construct mixing keys and keypaths
KeyPath(:a, :b, 3, :c, 4)
julia> struct T
           a
           b
       end
julia> function Base.getproperty(x::T, k::Symbol)
           if k in fieldnames(T)
               return getfield(x, k)
           elseif k === :ab
               return "ab"
           else
               error()
           end
       end;
julia> Base.propertynames(::T) = (:a, :b, :ab);
julia> x = T(3, Dict(:c => 4, :d => 5));
julia> getkeypath(x, KeyPath(:ab)) # equivalent to x.ab
"ab"
julia> getkeypath(x, KeyPath(:b, :c)) # equivalent to (x.b)[:c]
4
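The lookup rules above can be sketched in a few lines of plain Julia. `walk_keys` is a hypothetical helper. Treating Symbol keys on dictionaries via getindex is an assumption of this sketch; the haskeypath and getkeypath examples below do index dictionaries with Symbol keys.

```julia
# Hypothetical sketch of following a key path (not Functors.jl's implementation):
# Symbol keys on non-dictionary values use getproperty; everything else uses getindex.
function walk_keys(x, path...)
    for k in path
        x = (k isa Symbol && !(x isa AbstractDict)) ? getproperty(x, k) : getindex(x, k)
    end
    return x
end

d = Dict(:a => 3, :b => Dict(:c => 4, "d" => [5, 6, 7]))
nt = (a = [1.0, 2.0], b = ((c = [3.0],), [6.0, 7.0]))
walk_keys(d, :b, "d", 2)    # like getkeypath(d, KeyPath(:b, "d", 2)), gives 6
walk_keys(nt, :b, 1, :c)    # NamedTuple field, then tuple index, then field: [3.0]
```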
Functors.haskeypath (Function)
haskeypath(x, kp::KeyPath)
Return true if x has a value at the path kp.
See also KeyPath, getkeypath, and setkeypath!.
Examples
julia> x = Dict(:a => 3, :b => Dict(:c => 4, "d" => [5, 6, 7]))
Dict{Symbol, Any} with 2 entries:
:a => 3
:b => Dict{Any, Any}(:c=>4, "d"=>[5, 6, 7])
julia> haskeypath(x, KeyPath(:a))
true
julia> haskeypath(x, KeyPath(:b, "d", 1))
true
julia> haskeypath(x, KeyPath(:b, "d", 4))
false
Functors.getkeypath (Function)
getkeypath(x, kp::KeyPath)
Return the value in x at the path kp.
See also KeyPath, haskeypath, and setkeypath!.
Examples
julia> x = Dict(:a => 3, :b => Dict(:c => 4, "d" => [5, 6, 7]))
Dict{Symbol, Any} with 2 entries:
:a => 3
:b => Dict{Any, Any}(:c=>4, "d"=>[5, 6, 7])
julia> getkeypath(x, KeyPath(:b, "d", 2))
6
Functors.setkeypath! (Function)
setkeypath!(x, kp::KeyPath, v)
Set the value in x at the path kp to v.
See also KeyPath, getkeypath, and haskeypath.