Optimisation Rules

Optimisers.Descent (Type)
Descent(η = 1f-1)
Descent(; eta)

Classic gradient descent optimiser with learning rate η. For each parameter p and its gradient dp, this runs p -= η*dp.

Parameters

  • Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
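The rule p -= η*dp can be traced by hand. The sketch below is plain Julia and makes no Optimisers.jl calls. The helper `descent_step!` and the quadratic loss are hypothetical, chosen so that the gradient is simply 2p.

```julia
# Hypothetical helper: one plain gradient-descent step, p .-= η .* dp
function descent_step!(p::AbstractVector, dp::AbstractVector, η::Real)
    p .-= η .* dp   # mutate the parameters in place
    return p
end

loss(p) = sum(abs2, p)   # example loss
grad(p) = 2 .* p         # its gradient

p = [1.0, -2.0]
descent_step!(p, grad(p), 0.1)   # p ≈ [0.8, -1.6]
```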
Optimisers.Momentum (Type)
Momentum(η = 0.01, ρ = 0.9)
Momentum(; [eta, rho])

Gradient descent optimizer with learning rate η and momentum ρ.

Parameters

  • Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
  • Momentum (ρ == rho): Controls the acceleration of gradient descent in the prominent direction, in effect dampening oscillations.
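The following plain-Julia sketch shows how ρ accumulates past gradients. It assumes the velocity form v ← ρ*v + η*dp followed by p ← p - v. The function name `momentum_step!` is hypothetical, and this is not the library's source.

```julia
# Sketch of heavy-ball momentum, assuming the form
#   v ← ρ*v + η*dp ;  p ← p - v
function momentum_step!(p, v, dp; η = 0.01, ρ = 0.9)
    @. v = ρ * v + η * dp   # velocity accumulates past gradients
    @. p -= v               # move along the velocity
    return p, v
end

p, v = [1.0], [0.0]
momentum_step!(p, v, [1.0])   # v ≈ [0.01],  p ≈ [0.99]
momentum_step!(p, v, [1.0])   # v ≈ [0.019], p ≈ [0.971]
```

With a constant gradient, each step is larger than the one before, which is the acceleration described above.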
Optimisers.Nesterov (Type)
Nesterov(η = 0.001, ρ = 0.9)

Gradient descent optimizer with learning rate η and Nesterov momentum ρ.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Nesterov momentum (ρ): Controls the acceleration of gradient descent in the prominent direction, in effect dampening oscillations.
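Below is a plain-Julia sketch of one Nesterov step. It assumes the common reformulation that avoids evaluating the gradient at a look-ahead point: step = -ρ²v + (1+ρ)η*dp, then v ← ρv - η*dp. The name `nesterov_step!` is hypothetical. From a zero velocity, the stored velocity after one step is -η*dp, which agrees with the state shown in the adjust! example further down this page.

```julia
# Sketch of Nesterov momentum in the "look-ahead free" form
#   step = -ρ^2 * v + (1 + ρ) * η * dp
#   v    ← ρ * v - η * dp
#   p    ← p - step
function nesterov_step!(p, v, dp; η = 0.001, ρ = 0.9)
    step = @. -ρ^2 * v + (1 + ρ) * η * dp   # uses the old velocity
    @. v = ρ * v - η * dp                   # then update the velocity
    @. p -= step
    return p, v
end

p, v = [0.5, 0.5], [0.0, 0.0]
nesterov_step!(p, v, [16.0, 88.0])   # v ≈ [-0.016, -0.088]
```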
Optimisers.Rprop (Type)
Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))

Optimizer using the Rprop algorithm. A full-batch learning algorithm that depends only on the sign of the gradient.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Scaling factors (ℓ::Tuple): Multiplicative decrease and increase factors for the step size, applied when the gradient changes sign or keeps its sign, respectively.
  • Step sizes (Γ::Tuple): Minimal and maximal allowed step sizes.
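Because Rprop uses only the sign of the gradient, each entry carries its own step size, and that step grows or shrinks depending on whether the sign repeats. Here is a plain-Julia sketch of one step. It assumes the iRprop− variant, which also skips the update after a sign flip, and it assumes ℓ is ordered (decrease, increase) as the defaults suggest. Whether Optimisers.jl follows exactly this variant is an assumption, and `rprop_update` is a hypothetical name.

```julia
# Sketch of the Rprop step-size adaptation (iRprop- style), elementwise:
#   same gradient sign as last time -> grow step by ℓ[2], capped at Γ[2]
#   sign flipped                     -> shrink step by ℓ[1], floored at Γ[1],
#                                       and skip this update
#   otherwise (a zero involved)      -> keep the step
function rprop_update(prev, steps, g; ℓ = (0.5, 1.2), Γ = (1e-6, 50.0))
    newsteps = map(prev, steps, g) do gp, s, gi
        if gp * gi > 0
            min(s * ℓ[2], Γ[2])
        elseif gp * gi < 0
            max(s * ℓ[1], Γ[1])
        else
            s
        end
    end
    # after a sign flip, forget the gradient so the next step is neutral
    newprev = map((gp, gi) -> gp * gi < 0 ? zero(gi) : gi, prev, g)
    delta = newsteps .* sign.(newprev)   # amount to subtract from the parameters
    return newprev, newsteps, delta
end

prev  = [1.0, 1.0, 0.0]   # gradient from the previous step
steps = fill(1e-3, 3)     # initial step sizes η
newprev, newsteps, delta = rprop_update(prev, steps, [2.0, -3.0, 5.0])
# newsteps ≈ [0.0012, 0.0005, 0.001]; delta ≈ [0.0012, 0.0, 0.001]
```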

Optimisers.RMSProp (Type)
RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred = false)
RMSProp(; [eta, rho, epsilon, centred])

Optimizer using the RMSProp algorithm. Often a good choice for recurrent networks. Parameters other than learning rate generally don't need tuning.

Centred RMSProp is a variant which normalises gradients by an estimate of their variance, instead of their second moment.

Parameters

  • Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
  • Decay rate (ρ == rho): Exponential decay rate of the running average of squared gradients.
  • Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default)
  • Keyword centred (or centered): Indicates whether to use the centred variant of the algorithm.
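The sketch below shows the difference between plain and centred RMSProp in one step of plain Julia. It assumes the textbook formulas: a running mean of g² always, plus a running mean of g in the centred case, so that a - m² estimates the variance. The name `rmsprop_delta!` is hypothetical.

```julia
# Sketch of one RMSProp step (plain and centred), state kept in `a` (mean square)
# and `m` (mean, only used by the centred variant).
function rmsprop_delta!(a, m, g; η = 0.001, ρ = 0.9, ϵ = 1e-8, centred = false)
    @. a = ρ * a + (1 - ρ) * g^2                 # running mean of squared gradients
    if centred
        @. m = ρ * m + (1 - ρ) * g               # running mean of gradients
        return @. η * g / (sqrt(a - m^2) + ϵ)    # normalise by a variance estimate
    else
        return @. η * g / (sqrt(a) + ϵ)          # normalise by the second moment
    end
end

a, m = [0.0], [0.0]
d  = rmsprop_delta!(a, m, [2.0])                           # a ≈ [0.4]
dc = rmsprop_delta!([0.0], [0.0], [2.0]; centred = true)   # denominator ≈ 0.6
```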
Optimisers.Adam (Type)
Adam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)

Adam optimiser.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Decay of moments (β::Tuple): Exponential decay rates for the first (β1) and second (β2) moment estimates.
  • Machine epsilon (ϵ): Constant to prevent division by zero (no need to change default)
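The plain-Julia sketch below illustrates how β1 and β2 enter one Adam step. It assumes the textbook algorithm with bias correction, where t counts steps from 1. The name `adam_delta!` is hypothetical.

```julia
# Sketch of one Adam step with bias correction; t counts steps from 1.
function adam_delta!(m, v, g, t; η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
    β1, β2 = β
    @. m = β1 * m + (1 - β1) * g      # first moment estimate
    @. v = β2 * v + (1 - β2) * g^2    # second (raw) moment estimate
    mhat = m ./ (1 - β1^t)            # bias-corrected estimates
    vhat = v ./ (1 - β2^t)
    return @. η * mhat / (sqrt(vhat) + ϵ)
end

m, v = zeros(2), zeros(2)
d = adam_delta!(m, v, [0.5, -2.0], 1)   # ≈ [0.001, -0.001]
```

On the first step, the bias correction makes the update close to η*sign(g), whatever the scale of the gradient.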
Optimisers.RAdam (Type)
RAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)

Rectified Adam optimizer.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Decay of moments (β::Tuple): Exponential decay rates for the first (β1) and second (β2) moment estimates.
  • Machine epsilon (ϵ): Constant to prevent division by zero (no need to change default)
Optimisers.AdaMax (Type)
AdaMax(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)

AdaMax is a variant of Adam based on the ∞-norm.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Decay of moments (β::Tuple): Exponential decay rates for the first (β1) and second (β2) moment estimates.
  • Machine epsilon (ϵ): Constant to prevent division by zero (no need to change default)
Optimisers.OAdam (Type)
OAdam(η = 0.001, β = (0.5, 0.9), ϵ = 1e-8)

OAdam (Optimistic Adam) is a variant of Adam adding an "optimistic" term suitable for adversarial training.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Decay of moments (β::Tuple): Exponential decay rates for the first (β1) and second (β2) moment estimates.
  • Machine epsilon (ϵ): Constant to prevent division by zero (no need to change default)
Optimisers.AdaGrad (Type)
AdaGrad(η = 0.1, ϵ = 1e-8)

AdaGrad optimizer. It has parameter-specific learning rates based on how frequently each parameter is updated. Parameters don't need tuning.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Machine epsilon (ϵ): Constant to prevent division by zero (no need to change default)
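Here is a plain-Julia sketch of the per-parameter learning rate. It assumes the textbook AdaGrad update, in which a running sum of squared gradients divides each step. Entries that keep receiving large gradients therefore take ever smaller steps. The name `adagrad_delta!` is hypothetical.

```julia
# Sketch of AdaGrad: each entry's step shrinks with its accumulated squared gradient.
function adagrad_delta!(acc, g; η = 0.1, ϵ = 1e-8)
    @. acc += g^2                       # running sum of squared gradients
    return @. η * g / (sqrt(acc) + ϵ)
end

acc = [0.0]
d1 = adagrad_delta!(acc, [3.0])   # ≈ [0.1]   (acc = 9)
d2 = adagrad_delta!(acc, [4.0])   # ≈ [0.08]  (acc = 25)
```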
Optimisers.AdaDelta (Type)
AdaDelta(ρ = 0.9, ϵ = 1e-8)

AdaDelta is a version of AdaGrad adapting its learning rate based on a window of past gradient updates. Parameters don't need tuning.

Parameters

  • Rho (ρ): Factor by which the gradient is decayed at each time step.
  • Machine epsilon (ϵ): Constant to prevent division by zero (no need to change default)
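AdaDelta has no learning-rate parameter. The sketch below, in plain Julia, shows where the step size comes from instead. It assumes the textbook algorithm: the step is the ratio of a decayed average of past squared updates to a decayed average of squared gradients, both decayed by ρ. The name `adadelta_delta!` is hypothetical.

```julia
# Sketch of one AdaDelta step. `acc` tracks squared gradients and `accΔ` tracks
# squared updates; their ratio sets a per-entry step size, so there is no η.
function adadelta_delta!(acc, accΔ, g; ρ = 0.9, ϵ = 1e-8)
    @. acc = ρ * acc + (1 - ρ) * g^2
    Δ = @. g * sqrt(accΔ + ϵ) / sqrt(acc + ϵ)
    @. accΔ = ρ * accΔ + (1 - ρ) * Δ^2
    return Δ
end

acc, accΔ = [0.0], [0.0]
Δ = adadelta_delta!(acc, accΔ, [1.0])   # first step is tiny: accΔ starts at zero
```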
Optimisers.AMSGrad (Type)
AMSGrad(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)

The AMSGrad version of the Adam optimiser. Parameters don't need tuning.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Decay of moments (β::Tuple): Exponential decay rates for the first (β1) and second (β2) moment estimates.
  • Machine epsilon (ϵ): Constant to prevent division by zero (no need to change default)
Optimisers.NAdam (Type)
NAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)

NAdam is a Nesterov variant of Adam. Parameters don't need tuning.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Decay of moments (β::Tuple): Exponential decay rates for the first (β1) and second (β2) moment estimates.
  • Machine epsilon (ϵ): Constant to prevent division by zero (no need to change default)
Optimisers.AdamW (Function)
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
AdamW(; [eta, beta, lambda, epsilon])

AdamW is a variant of Adam fixing (as in repairing) its weight decay regularization. Implemented as an OptimiserChain of Adam and WeightDecay.

Parameters

  • Learning rate (η == eta): Amount by which gradients are discounted before updating the weights.
  • Decay of moments (β::Tuple == beta): Exponential decay rates for the first (β1) and second (β2) moment estimates.
  • Weight decay (λ == lambda): Controls the strength of $L_2$ regularisation.
  • Machine epsilon (ϵ == epsilon): Constant to prevent division by zero (no need to change default)
Optimisers.AdaBelief (Type)
AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)

The AdaBelief optimiser is a variant of the well-known Adam optimiser.

Parameters

  • Learning rate (η): Amount by which gradients are discounted before updating the weights.
  • Decay of moments (β::Tuple): Exponential decay rates for the first (β1) and second (β2) moment estimates.
  • Machine epsilon (ϵ::Float32): Constant to prevent division by zero (no need to change default)
Optimisers.Lion (Type)
Lion(η = 0.001, β = (0.9, 0.999))

Lion optimiser.

Parameters

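  • Learning rate (η): Size of the step applied to each weight. The update is η times a sign, so every entry moves by exactly η (or not at all).
  • Decay of momentum (β::Tuple): β1 interpolates between the stored momentum and the current gradient to form the update, and β2 is the exponential decay rate of the momentum itself.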
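The sketch below, in plain Julia, shows the roles of η and the two β values in one Lion step. It assumes the algorithm as published: take the sign of a β1-interpolation between the momentum and the gradient, then update the momentum with β2. The name `lion_delta!` is hypothetical.

```julia
# Sketch of one Lion step: the update is the *sign* of an interpolation between
# momentum and gradient, so every entry moves by exactly η (or 0).
function lion_delta!(mom, g; η = 0.001, β = (0.9, 0.999))
    β1, β2 = β
    delta = @. η * sign(β1 * mom + (1 - β1) * g)   # uses the old momentum
    @. mom = β2 * mom + (1 - β2) * g               # then update the momentum
    return delta
end

mom = zeros(2)
d = lion_delta!(mom, [0.3, -5.0])   # [0.001, -0.001], regardless of gradient size
```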

In addition to the main course, you may wish to order some of these condiments:

Optimisers.AccumGrad (Type)
AccumGrad(n::Int)

A rule constructed as OptimiserChain(AccumGrad(n), Rule()) will accumulate gradients for n steps, before applying Rule to the mean of these n gradients.

This is useful for training with effective batch sizes too large for the available memory. Instead of computing the gradient for batch size b at once, compute it for size b/n and accumulate n such gradients.

Example

julia> m = (x=[1f0], y=[2f0]);

julia> r = OptimiserChain(AccumGrad(2), WeightDecay(0.01), Descent(0.1));

julia> s = Optimisers.setup(r, m);

julia> Optimisers.update!(s, m, (x=[33], y=[0]));

julia> m  # model not yet changed
(x = Float32[1.0], y = Float32[2.0])

julia> Optimisers.update!(s, m, (x=[0], y=[444]));

julia> m  # n=2 gradients applied at once
(x = Float32[-0.651], y = Float32[-20.202002])
Optimisers.ClipNorm (Type)
ClipNorm(ω = 10, p = 2; throw = true)

Scales any gradient array for which norm(dx, p) > ω to stay at this threshold (unless p==0).

Throws an error if the norm is infinite or NaN, which you can turn off with throw = false.

Typically composed with other rules using OptimiserChain.

See also ClipGrad.
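The rescaling rule described above can be sketched in plain Julia using only `LinearAlgebra.norm`. The function `clipnorm` is hypothetical, and the special case p == 0 is not modelled here.

```julia
using LinearAlgebra: norm

# Sketch of ClipNorm's rule: rescale dx so that norm(dx, p) never exceeds ω.
function clipnorm(dx; ω = 10.0, p = 2, throw = true)
    n = norm(dx, p)
    if throw && !isfinite(n)
        error("gradient has non-finite norm $n")
    end
    return n > ω ? dx .* (ω / n) : dx   # keeps the direction, caps the length
end

clipnorm([3.0, 4.0]; ω = 1.0)   # norm 5 → rescaled to [0.6, 0.8]
clipnorm([0.3, 0.4]; ω = 1.0)   # norm 0.5 ≤ ω → unchanged
```

Unlike elementwise clipping, this keeps the direction of the gradient and changes only its length.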

Optimisers.SignDecay (Type)
SignDecay(λ = 1e-3)

Implements $L_1$ regularisation, also known as LASSO regression, when composed with other rules as the first transformation in an OptimiserChain.

It does this by adding λ .* sign(x) to the gradient. This is equivalent to adding λ * sum(abs, x) == λ * norm(x, 1) to the loss.

See also WeightDecay for $L_2$ regularisation. They can be used together: OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam()) is equivalent to adding 0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2 to the loss function.

Parameters

  • Penalty (λ ≥ 0): Controls the strength of the regularisation.
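The equivalence claimed above can be checked numerically in plain Julia. Central finite differences of the penalty should match the term that SignDecay(0.012) followed by WeightDecay(0.034) adds to the gradient. The helper names are hypothetical, and the test point avoids x = 0, where sign(x) is not a derivative.

```julia
using LinearAlgebra: norm

# Penalty from the text, and the term the two decay rules add to the gradient
penalty(x) = 0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2
decay_grad(x) = 0.012 .* sign.(x) .+ 0.034 .* x

# Central finite-difference gradient of f at x
function fd_grad(f, x; h = 1e-6)
    map(eachindex(x)) do i
        e = zeros(length(x)); e[i] = h
        (f(x .+ e) - f(x .- e)) / (2h)
    end
end

x = [2.0, -3.0, 0.5]
isapprox(fd_grad(penalty, x), decay_grad(x); atol = 1e-7)   # true
```

Note the factor of two: WeightDecay(0.034) corresponds to 0.017 * norm(x)^2, because the derivative of λ/2 * x² is λx.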
Optimisers.WeightDecay (Type)
WeightDecay(λ = 5e-4)

Implements $L_2$ regularisation, also known as ridge regression, when composed with other rules as the first transformation in an OptimiserChain.

It does this by adding λ .* x to the gradient. This is equivalent to adding λ/2 * sum(abs2, x) == λ/2 * norm(x)^2 to the loss.

See also SignDecay for $L_1$ regularisation.

Parameters

  • Penalty (λ ≥ 0): Controls the strength of the regularisation.
Optimisers.OptimiserChain (Type)
OptimiserChain(opts...)

Compose a sequence of optimisers so that each opt in opts updates the gradient, in the order specified.

With an empty sequence, OptimiserChain() is the identity, so update! will subtract the full gradient from the parameters. This is equivalent to Descent(1).

Example

julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));

julia> m = (zeros(3),);

julia> s = Optimisers.setup(o, m)
(Leaf(OptimiserChain(ClipGrad(1.0), Descent(0.1)), (nothing, nothing)),)

julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2]  # clips before discounting
([-0.03, -0.1, -0.1],)
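Conceptually, a chain is a left-to-right fold of gradient transformations. The plain-Julia sketch below reproduces the numbers of the example above. The closures `clip` and `scale` are hypothetical stand-ins for ClipGrad(1.0) and Descent(0.1), not the library's types.

```julia
# Sketch: a chain is a left-to-right fold of gradient transformations.
clip(δ)  = dx -> clamp.(dx, -δ, δ)   # stand-in for ClipGrad(δ)
scale(η) = dx -> η .* dx             # stand-in for Descent(η)

chain(fs...) = dx -> foldl((g, f) -> f(g), fs; init = dx)

x  = zeros(3)
dx = [0.3, 1.0, 7.0]
x_new = x .- chain(clip(1.0), scale(0.1))(dx)   # ≈ [-0.03, -0.1, -0.1]
identity_step = x .- chain()(dx)                # empty chain: the full gradient
```

Order matters. Scaling first and clipping second would leave 0.7 in the last entry, because after scaling nothing exceeds the clipping threshold.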

Model Interface

Optimisers.setup (Function)
Optimisers.setup(rule, model) -> state_tree

Initialises the given optimiser for every trainable parameter within the model. Returns a tree of the relevant states, which must be passed to update or update!.

Example

julia> m = (x = rand(3), y = (true, false), z = tanh);

julia> Optimisers.setup(Momentum(), m)  # same field names as m
(x = Leaf(Momentum(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())

The recursion into structures uses Functors.jl, and any new structs containing parameters need to be marked with Functors.@functor before use. See the Flux docs for more about this.

julia> struct Layer; mat; fun; end

julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);

julia> Optimisers.setup(Momentum(), model)  # new struct is by default ignored
(lay = (), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))

julia> destructure(model)
(Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))

julia> using Functors; @functor Layer  # annotate this type as containing parameters

julia> Optimisers.setup(Momentum(), model)
(lay = (mat = Leaf(Momentum(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))

julia> destructure(model)
(Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
Optimisers.update (Function)
Optimisers.update(tree, model, gradient) -> (tree, model)

Uses the optimiser and the gradient to change the trainable parameters in the model. Returns a tuple (tree, model): the optimiser states needed for the next update, and the improved model. The initial tree of states comes from setup.

See also update!, which will be faster for models of ordinary Arrays or CuArrays.

Example

julia> m = (x = Float32[1,2,3], y = tanh);

julia> t = Optimisers.setup(Descent(0.1), m)
(x = Leaf(Descent(0.1), nothing), y = ())

julia> g = (x = [1,1,1], y = nothing);  # fake gradient

julia> Optimisers.update(t, m, g)
((x = Leaf(Descent(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
Optimisers.update! (Function)
Optimisers.update!(tree, model, gradient) -> (tree, model)

Uses the optimiser and the gradient to change the trainable parameters in the model. Returns a tuple (tree, model): the optimiser states needed for the next update, and the improved model. The initial tree of states comes from setup.

This is used in exactly the same manner as update, but because it may mutate arrays within the old model (and the old state), it will be faster for models of ordinary Arrays or CuArrays. However, you should not rely on the old model being fully updated but rather use the returned model. (The original state tree is always mutated, as each Leaf is mutable.)
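You should use the returned model because only some arrays can be overwritten. The plain-Julia sketch below illustrates this "mutate when possible" behaviour for a plain descent step, using the `x isa DenseArray` test that maywrite is documented to use. The function `descent_update` is hypothetical. The real update! is generic over rules and model trees.

```julia
# Sketch of "mutate when possible": write into x only if it is a DenseArray
# (mirroring the maywrite rule), otherwise build a new array.
function descent_update(x, dx, η)
    if x isa DenseArray
        x .-= η .* dx          # in place: the old model sees the change
        return x
    else
        return x .- η .* dx    # out of place: the old object is untouched
    end
end

dense = [1.0, 2.0]
rng   = 1.0:1.0:2.0            # a range is not a DenseArray
new_dense = descent_update(dense, [1.0, 1.0], 0.5)   # same object as `dense`
new_rng   = descent_update(rng,   [1.0, 1.0], 0.5)   # a new Vector
```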

Example

julia> using StaticArrays, Zygote, Optimisers

julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]);  # partly mutable model

julia> t = Optimisers.setup(Momentum(1/30, 0.9), m)  # tree of states
(x = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]))

julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]  # structural gradient
(x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])

julia> t2, m2 = Optimisers.update!(t, m, g);

julia> m2  # after update or update!, this is the new model
(x = Float32[0.6666666, 1.5333333], y = Float32[3.6666667, 4.5333333])

julia> m2.x === m.x  # update! has re-used this array, for efficiency
true

julia> m  # original should be discarded, may be mutated but no guarantee
(x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])

julia> t == t2  # original state tree is guaranteed to be mutated
true
Optimisers.adjust! (Function)
Optimisers.adjust!(tree, η)

Alters the state tree = setup(rule, model) to change the parameters of the optimisation rule, without destroying its stored state. Typically used mid-way through training.

Can be applied to part of a model, by acting only on the corresponding part of the state tree.

To change just the learning rate, provide a number η::Real.

Example

julia> m = (vec = rand(Float32, 2), fun = sin);

julia> st = Optimisers.setup(Nesterov(), m)  # stored momentum is initialised to zero
(vec = Leaf(Nesterov(0.001, 0.9), Float32[0.0, 0.0]), fun = ())

julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing));  # with fake gradient

julia> st
(vec = Leaf(Nesterov(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())

julia> Optimisers.adjust!(st, 0.123)  # change learning rate, stored momentum untouched

julia> st
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())

To change other parameters, adjust! also accepts keyword arguments matching the field names of the optimisation rule's type.

julia> fieldnames(Adam)
(:eta, :beta, :epsilon)

julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
(vec = Leaf(OptimiserChain(ClipGrad(10.0), Adam(0.001, (0.9, 0.999), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())

julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1)  # delta acts on ClipGrad
(vec = Leaf(OptimiserChain(ClipGrad(11.1), Adam(0.001, (0.777, 0.909), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())

julia> Optimisers.adjust(st; beta = "no such field")  # silently ignored!
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
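The keyword behaviour shown above (matching field names are replaced, unknown names are silently ignored) can be sketched for a single immutable rule in plain Julia. `MyAdam` and `adjust_rule` are hypothetical. The sketch only rebuilds the rule and does not model the stored state, which adjust! keeps.

```julia
# Sketch of keyword-based adjustment: rebuild an immutable rule, replacing only
# fields whose names match a keyword and silently ignoring the rest.
struct MyAdam                       # hypothetical stand-in for a rule type
    eta::Float64
    beta::Tuple{Float64, Float64}
    epsilon::Float64
end

function adjust_rule(rule::T; kw...) where {T}
    vals = map(f -> haskey(kw, f) ? kw[f] : getfield(rule, f), fieldnames(T))
    return T(vals...)
end

r  = MyAdam(0.001, (0.9, 0.999), 1e-8)
r2 = adjust_rule(r; beta = (0.777, 0.909), delta = 11.1)   # `delta` is not a field
```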
Optimisers.freeze! (Function)
Optimisers.freeze!(tree)

Temporarily alters the state tree = setup(rule, model) so that parameters will not be updated. Un-done by thaw!.

Can be applied to the state corresponding to only part of a model. For instance, with model::Chain, to freeze model.layers[1] you should call freeze!(tree.layers[1]).

Example

julia> m = (x = ([1.0], 2.0), y = [3.0]);

julia> s = Optimisers.setup(Momentum(), m);

julia> Optimisers.freeze!(s.x)

julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi]));  # with fake gradient

julia> m
(x = ([1.0], 2.0), y = [-0.14159265358979312])

julia> s
(x = (Leaf(Momentum(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum(0.01, 0.9), [3.14159]))

julia> Optimisers.thaw!(s)

julia> s.x
(Leaf(Momentum(0.01, 0.9), [0.0]), ())
Optimisers.thaw! (Function)
Optimisers.thaw!(tree)

The reverse of freeze!. Applies to all parameters, mutating every Leaf(rule, state, frozen = true) to Leaf(rule, state, frozen = false).

Calling Functors.@functor on your model's layer types by default causes these functions to recurse into all children, and ultimately optimise all isnumeric leaf nodes. To further restrict this by ignoring some fields of a layer type, define trainable:

Optimisers.trainable (Function)
trainable(x::Layer) -> NamedTuple

This may be overloaded to make optimisers ignore some fields of every Layer, which would otherwise contain trainable parameters.

Warning

This is very rarely required. Fields of struct Layer which contain functions, or integers like sizes, are always ignored anyway. Overloading trainable is only necessary when some arrays of numbers are to be optimised, and some arrays of numbers are not.

The default is Functors.children(x), usually a NamedTuple of all fields, and trainable(x) must contain a subset of these.

Optimisers.isnumeric (Function)
isnumeric(x) -> Bool

Returns true on any parameter to be adjusted by Optimisers.jl, namely arrays of non-integer numbers. Returns false on all other types.

Requires also that Functors.isleaf(x) == true, to focus on e.g. the parent of a transposed matrix, not the wrapper.
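The type test described above can be sketched in plain Julia. The sketch leaves out the Functors.isleaf requirement, so unlike the real function it would also accept a wrapper such as a transposed matrix. The name `isnumeric_sketch` is hypothetical.

```julia
# Sketch of the isnumeric test, ignoring the Functors.isleaf requirement:
# arrays whose elements are numbers but not integers (Bool counts as an integer).
isnumeric_sketch(x::AbstractArray{<:Number}) = !(eltype(x) <: Integer)
isnumeric_sketch(x) = false

isnumeric_sketch([1.0, 2.0])      # true
isnumeric_sketch([1, 2])          # false: integers, e.g. sizes or indices
isnumeric_sketch([true, false])   # false: Bool <: Integer
isnumeric_sketch(2.0)             # false: a scalar, not an array
```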

Optimisers.maywrite (Function)
maywrite(x) -> Bool

Should return true if we are completely sure that update! can write new values into x. Otherwise false, indicating a non-mutating path. For now, simply x isa DenseArray allowing Array, CuArray, etc.

Such restrictions are also obeyed by this function for flattening a model:

Optimisers.destructure (Function)
destructure(model) -> vector, reconstructor

Copies all trainable, isnumeric parameters in the model to a vector, and also returns a function which reverses this transformation. Differentiable.

Example

julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))

julia> re([3, 5, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))

If model contains various number types, they are promoted to make the vector, and are usually restored by Restructure. Such restoration follows the rules of ChainRulesCore.ProjectTo, and thus will restore floating point precision, but will permit more exotic numbers like ForwardDiff.Dual.

If model contains only GPU arrays, then vector will also live on the GPU. At present, a mixture of GPU and ordinary CPU arrays is undefined behaviour.
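The core of the flatten-and-rebuild idea can be sketched in plain Julia for a flat NamedTuple of arrays sharing one element type. The sketch does no promotion, no ProjectTo, no nesting, and no handling of tied weights, and `destructure_sketch` is a hypothetical name. Arrays are read in column-major order, as in the destructure output shown for setup.

```julia
# Sketch of flattening a NamedTuple of arrays into one vector, plus a closure
# that rebuilds the same structure from a new vector (column-major order).
function destructure_sketch(nt::NamedTuple)
    arrays = values(nt)
    flat   = reduce(vcat, [vec(a) for a in arrays])
    lens   = [length(a) for a in arrays]
    stops  = cumsum(lens)
    starts = stops .- lens .+ 1
    rebuild(v::AbstractVector) = NamedTuple{keys(nt)}(Tuple(
        reshape(v[s:e], size(a)) for (s, e, a) in zip(starts, stops, arrays)))
    return flat, rebuild
end

model = (x = [1.0, 2.0], y = [3.0 5.0; 4.0 6.0])
flat, re = destructure_sketch(model)   # flat == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
re(10 .* flat).y                       # [30.0 50.0; 40.0 60.0]
```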

Optimisers.Restructure (Type)
Restructure(Model, ..., length)

This is what destructure returns, and re(p) will re-build the model with new parameters from vector p. If the model is callable, then re(x, p) == re(p)(x).

Example

julia> using Flux, Optimisers

julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid))
([1, 3, 2, 4, 0, 0], Restructure(Dense, ..., 6))

julia> m = re(-4:1)
Dense(2, 2, σ)      # 6 parameters

julia> m([0.2, 0.3]) ≈ re([0.2, 0.3], -4:1)
true
Optimisers.trainables (Function)
trainables(x; path = false)

Return an iterable over all the trainable parameters in x, that is all the numerical arrays (see isnumeric) which are reachable through trainable.

Parameters appearing multiple times in the model (tied weights) will be present only once in the output.

If path = false, the output is a list of numerical arrays.

If path = true, the output is a list of (KeyPath, AbstractArray) pairs, where KeyPath is a type representing the path to the array in the original structure.

See also destructure for a similar operation that returns a single flat vector instead.

Examples

julia> struct MyLayer
         w
         b
       end

julia> Functors.@functor MyLayer

julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,) # only w is trainable in this example

julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]);

julia> trainables(x)
1-element Vector{AbstractArray}:
 [1.0, 2.0, 3.0]

julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);

julia> trainables(x) # collects nested parameters
2-element Vector{AbstractArray}:
 [1.0, 2.0]
 [3.0]

julia> x = (a = [1.0,2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0,7.0]));

julia> for (kp, y) in trainables(x, path = true)
           println(kp, " => ", y)
       end
KeyPath(:a,) => [1.0, 2.0]
KeyPath(:b, 1, "c") => [3.0, 4.0]
KeyPath(:b, 2) => [6.0, 7.0]

julia> getkeypath(x, KeyPath(:b, 1, "c"))
2-element Vector{Float64}:
 3.0
 4.0

Rule Definition

Optimisers.apply! (Function)
Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)

This defines the action of any optimisation rule. It should return the modified gradient which will be subtracted from the parameters, and the updated state (if any) for use at the next iteration, as a tuple (state, gradient).

For efficiency it is free to mutate the old state, but only what is returned will be used. Ideally this should check maywrite(x), which the built-in rules do via the @.. macro.

The initial state is init(rule::RuleType, parameters).

Example

julia> Optimisers.init(Descent(0.1), Float32[1,2,3]) === nothing
true

julia> Optimisers.apply!(Descent(0.1), nothing, Float32[1,2,3], [4,5,6])
(nothing, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(*, ([4, 5, 6], 0.1f0)))
Optimisers.init (Function)
Optimisers.init(rule::RuleType, parameters) -> state

Sets up the initial state for a given optimisation rule, and an array of parameters. This and apply! are the two functions which any new optimisation rule must define.

Examples

julia> Optimisers.init(Descent(), Float32[1,2,3])  # is `nothing`

julia> Optimisers.init(Momentum(), [1.0, 2.0])
2-element Vector{Float64}:
 0.0
 0.0
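The two required methods fit together as in the following sketch of a new rule. It needs Optimisers.jl installed. The rule `SignDescent` is hypothetical. The sketch assumes, as in the Optimisers.jl README, that new rules subtype `Optimisers.AbstractRule`. It names its learning-rate field eta so that adjust works without extra methods.

```julia
using Optimisers

# A hypothetical rule for illustration: step by η times the sign of the gradient.
struct SignDescent <: Optimisers.AbstractRule
    eta::Float64
end

# No state is needed, so init returns nothing.
Optimisers.init(o::SignDescent, x::AbstractArray) = nothing

# Return (new_state, modified_gradient); the gradient is subtracted from x.
function Optimisers.apply!(o::SignDescent, state, x, dx)
    T = eltype(x)
    return state, T(o.eta) .* sign.(dx)
end

model = (w = [1.0, 2.0],)
st = Optimisers.setup(SignDescent(0.1), model)
st, model = Optimisers.update(st, model, (w = [3.0, -0.5],))
model.w   # ≈ [0.9, 2.1]
```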
Optimisers.@.. (Macro)
@.. x = y + z

A sometimes-in-place broadcasting macro, for use in apply! rules. If maywrite(x) then it is just @. x = rhs, but if not, it becomes x = @. rhs.

Optimisers.@lazy (Macro)
x = @lazy y + z

Lazy broadcasting macro, for use in apply! rules. It broadcasts like @. but does not materialise, returning a Broadcasted object for later use. Beware that mutation of arguments will affect the result, and that if it is used in two places, work will be done twice.
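The caveat about mutation can be seen with Base's own lazy broadcasting: `Base.Broadcast.broadcasted` builds a Broadcasted object without computing anything, and `Base.Broadcast.materialize` evaluates it. These are Base functions, not Optimisers.jl API. That @lazy produces an equivalent object is an assumption, based on the Broadcasted value shown in the apply! example.

```julia
using Base.Broadcast: broadcasted, materialize

y = [1.0, 2.0]
z = [10.0, 20.0]

lazy = broadcasted(+, y, z)   # a Broadcasted object: nothing computed yet
y[1] = 100.0                  # mutating an argument before materialising...
result = materialize(lazy)    # ...changes the result: [110.0, 22.0]
```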

Optimisers.adjust (Method)
Optimisers.adjust(rule::RuleType, η::Real) -> rule

If a new optimisation rule has a learning rate which is not stored in field rule.eta, then you should add a method to adjust. (But it is simpler to just use the standard name.)

Optimisers.@def (Macro)

@def struct Rule; eta = 0.1; beta = (0.7, 0.8); end

Helper macro for defining rules with default values. The types of the literal values are used in the struct, like this:

struct Rule
  eta::Float64
  beta::Tuple{Float64, Float64}
  Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
  Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
end

Any field called eta is assumed to be a learning rate, and cannot be negative.

KeyPath

A KeyPath is a sequence of keys that can be used to access a value within a nested structure. It is defined in Functors.jl and re-exported by Optimisers.jl here for convenience.

Functors.KeyPath (Type)
KeyPath(keys...)

A type for representing a path of keys to a value in a nested structure. Can be constructed with a sequence of keys, or by concatenating other KeyPaths. Keys can be of type Symbol, String, or Int.

For custom types, access through symbol keys is assumed to be done with getproperty. For consistency, the method Base.propertynames is used to get the viable property names.

For string and integer keys instead, the access is done with getindex.

See also getkeypath, haskeypath.
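The lookup rule above can be sketched as a fold over the keys. Symbol keys use getproperty on structs and NamedTuples. String and Int keys, and any key on a Dict, use getindex, which matches the (x.b)[:c] example below. The name `getkeypath_sketch` is hypothetical, and it takes the keys directly rather than a KeyPath.

```julia
# Sketch of the lookup rule: Symbol keys use getproperty on structs and
# NamedTuples, while String/Int keys, and any key on a Dict, use getindex.
function getkeypath_sketch(x, ks...)
    foldl(ks; init = x) do v, k
        (k isa Symbol && !(v isa AbstractDict)) ? getproperty(v, k) : v[k]
    end
end

x = (a = [1.0, 2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0, 7.0]))
getkeypath_sketch(x, :b, 1, "c")       # [3.0, 4.0]
getkeypath_sketch(Dict(:c => 4), :c)   # 4: a Symbol key on a Dict uses getindex
```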

Examples

julia> kp = KeyPath(:b, 3)
KeyPath(:b, 3)

julia> KeyPath(:a, kp, :c, 4) # construct mixing keys and keypaths
KeyPath(:a, :b, 3, :c, 4)

julia> struct T
           a
           b
       end

julia> function Base.getproperty(x::T, k::Symbol)
            if k in fieldnames(T)
                return getfield(x, k)
            elseif k === :ab
                return "ab"
            else        
                error()
            end
        end;

julia> Base.propertynames(::T) = (:a, :b, :ab);

julia> x = T(3, Dict(:c => 4, :d => 5));

julia> getkeypath(x, KeyPath(:ab)) # equivalent to x.ab
"ab"

julia> getkeypath(x, KeyPath(:b, :c)) # equivalent to (x.b)[:c]
4
Functors.haskeypath (Function)
haskeypath(x, kp::KeyPath)

Return true if x has a value at the path kp.

See also KeyPath and getkeypath.

Examples

julia> x = Dict(:a => 3, :b => Dict(:c => 4, "d" => [5, 6, 7]))
Dict{Symbol, Any} with 2 entries:
  :a => 3
  :b => Dict{Any, Any}(:c=>4, "d"=>[5, 6, 7])

julia> haskeypath(x, KeyPath(:a))
true

julia> haskeypath(x, KeyPath(:b, "d", 1))
true

julia> haskeypath(x, KeyPath(:b, "d", 4))
false
Functors.getkeypath (Function)
getkeypath(x, kp::KeyPath)

Return the value in x at the path kp.

See also KeyPath and haskeypath.

Examples

julia> x = Dict(:a => 3, :b => Dict(:c => 4, "d" => [5, 6, 7]))
Dict{Symbol, Any} with 2 entries:
  :a => 3
  :b => Dict{Any, Any}(:c=>4, "d"=>[5, 6, 7])

julia> getkeypath(x, KeyPath(:b, "d", 2))
6