# Optimisers.jl

## An optimisation rule

A new optimiser must overload two functions, `apply!` and `init`. These act on one array of parameters:

```
# Define a container to hold any optimiser specific parameters (if any):
struct DecayDescent <: Optimisers.AbstractRule
    eta::Float64
end

# Define an `apply!` rule which encodes how the gradients will be used to
# update the parameters:
function Optimisers.apply!(o::DecayDescent, state, x, x̄)
    T = eltype(x)
    newx̄ = T(o.eta / √state) .* x̄
    nextstate = state + 1
    return nextstate, newx̄
end

# Define the function which sets up the initial state (if any):
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1
```

The parameters will be immediately updated to `x .- newx̄`, while `nextstate` is carried to the next iteration.

Notice that the state is handled separately from the optimiser itself. This is a key design principle and allows users to manage their own state explicitly. It of course also makes it easier to store the state.
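
As a minimal sketch of what this separation means in practice, the rule can be driven by hand on a single array, without `setup` or `update`. The block repeats the `DecayDescent` definition so that it runs on its own; the values of `x` and `x̄` are made up for illustration, and in real use the gradient would come from an automatic differentiation package.

```julia
using Optimisers

# Same rule as in the text, repeated so this block is self-contained:
struct DecayDescent <: Optimisers.AbstractRule
    eta::Float64
end

function Optimisers.apply!(o::DecayDescent, state, x, x̄)
    T = eltype(x)
    newx̄ = T(o.eta / √state) .* x̄
    return state + 1, newx̄
end

Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

rule = DecayDescent(0.1)
x = [1.0, 2.0, 3.0]    # illustrative parameters
x̄ = [0.5, 0.5, 0.5]    # illustrative gradient, not computed from a real loss

state = Optimisers.init(rule, x)                    # initial state, held by the caller
state, dx = Optimisers.apply!(rule, state, x, x̄)    # new state and the scaled gradient
x = x .- dx                                         # the caller applies the update
```

The rule object itself never changes: everything that evolves between steps lives in `state`, which the caller holds and passes back in on the next call.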

## Usage with Flux.jl

To apply such an optimiser to a whole model, `setup` builds a tree containing any initial state for every trainable array. Then at each step, `update` uses this and the gradient to adjust the model:

```
using Flux, Metalhead, Zygote, Optimisers

model = Metalhead.ResNet(18) |> gpu  # define a model to train
image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
@show sum(model(image));  # dummy loss function

rule = Optimisers.Adam()  # use the Adam optimiser with its default settings
state_tree = Optimisers.setup(rule, model);  # initialise this optimiser's momentum etc.

∇model, _ = gradient(model, image) do m, x  # calculate the gradients
    sum(m(x))
end;

state_tree, model = Optimisers.update(state_tree, model, ∇model);
@show sum(model(image));  # reduced
```

Notice that a completely new instance of the model is returned. Internally, this is handled by Functors.jl, where we do a walk over the tree formed by the model and update the parameters using the gradients.

There is also `Optimisers.update!` which similarly returns a new model, but is free to mutate arrays within the old one for efficiency. (The method of `apply!` above is likewise free to mutate arrays within its state; they are defensively copied when this rule is used with `update`.) For `Adam()`, there are two momenta per parameter, thus `state_tree` is about twice the size of `model`:

```
Base.summarysize(model) / 1024^2  # about 45MB
Base.summarysize(state_tree) / 1024^2  # about 90MB
```

Optimisers.jl does not depend on any one automatic differentiation package, but for now the most likely source of gradients is Zygote.jl. Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above. This `∇model` is another tree structure, rather than the dictionary-like object from Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))`. See Zygote's documentation for more about this difference.

## Usage with Lux.jl

The main design difference of Lux from Flux is that the tree of parameters is separate from the layer structure. It is these parameters which `setup` and `update` need to know about.

Lux describes this separation of parameter storage from model description as "explicit" parameters. Beware that it has nothing to do with Zygote's notion of "explicit" gradients. (If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be nearly identical trees of nested `NamedTuple`s.)

```
using Lux, Boltz, Zygote, Optimisers

lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
images = rand(Float32, 224, 224, 3, 4) |> gpu;  # batch of dummy data
y, lux_state = Lux.apply(lux_model, images, params, lux_state);  # run the model
@show sum(y);  # initial dummy loss

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params);  # optimiser state based on model parameters

(loss, lux_state), back = Zygote.pullback(params, images) do p, x
    y, st = Lux.apply(lux_model, x, p, lux_state)
    sum(y), st  # return both the loss, and the updated lux_state
end;
∇params, _ = back((one.(loss), nothing));  # gradient of only the loss, with respect to parameter tree
loss == sum(y)  # not yet changed

opt_state, params = Optimisers.update!(opt_state, params, ∇params);

y, lux_state = Lux.apply(lux_model, images, params, lux_state);
@show sum(y);  # now reduced
```

Besides the parameters stored in `params` and gradually optimised, any other model state is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.) This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.

```
Base.summarysize(lux_model) / 1024 # just 2KB
Base.summarysize(params) / 1024^2 # about 45MB, same as Flux model
Base.summarysize(lux_state) / 1024 # 40KB
Base.summarysize(opt_state) / 1024^2 # about 90MB, with Adam
```

If you are certain there is no model state, then the gradient calculation can be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:

```
∇params, _ = gradient(params, images) do p, x
    y, _ = Lux.apply(lux_model, x, p, lux_state)  # discards new lux_state
    sum(y)
end;
```

## Non-`trainable` Parameters

Optimisers.jl uses Functors.jl to walk the `struct`s making up the model, for which they must be annotated `@functor Type`. By default optimisation will alter all `isnumeric` arrays.

If some arrays of a particular layer should not be treated this way, you can define a method for `trainable`:

```
using Functors, Optimisers

struct Layer{T}
    alpha::T
    beta::T
    length::Int
end
Layer(n::Int) = Layer(randn(n), zeros(n), n)

Functors.@functor Layer

# Both array fields will be, for example, moved to the GPU:
Functors.children(Layer(3))  # (alpha = [...], beta = [...], length)

Optimisers.trainable(x::Layer) = (; alpha = x.alpha)  # must be a subset of children

# Only the first field will be optimised (DecayDescent is the rule defined above):
st = Optimisers.setup(DecayDescent(0.1), Layer(3))
```

## Frozen Parameters

To temporarily prevent training from affecting some parameters, use `freeze!` and `thaw!`. They work by mutating all `Leaf`s of the state tree, or part of it.

```
using Flux, Optimisers

x = randn(Float32, 28, 28, 1, 1);
net = @autosize (size(x)...,) Chain(
    Conv((3, 3), 1 => 3, stride=2, bias=false), Flux.flatten, Dense(_ => 2, relu),
)
opt = Optimisers.setup(Optimisers.Momentum(), net);

net.layers[3] isa Dense  # now freeze this layer's parameters:
Optimisers.freeze!(opt.layers[3])
opt.layers[3].bias  # confirm: Leaf(Momentum(...), [0.0, 0.0], frozen = true)

Optimisers.update!(opt, net, gradient(m -> sum(m(x)), net)...);

net.layers[3].bias  # still zero, and its momentum is too:

Optimisers.thaw!(opt)
opt.layers[3].bias  # Leaf(Momentum(...), [0.0, 0.0])
```
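
The same mechanism can be seen without Flux on a plain `NamedTuple` of arrays. This is a small sketch under assumptions: the model, its values, and the hand-written gradient are all invented for illustration, and `Descent` is used only because its update is easy to check by eye.

```julia
using Optimisers

# A toy "model": two parameter arrays in a NamedTuple (illustrative values).
model = (w = [1.0, 2.0], b = [3.0])
opt = Optimisers.setup(Optimisers.Descent(0.1), model)

# Freeze only the branch of the state tree belonging to `b`:
Optimisers.freeze!(opt.b)

# A made-up gradient with the same structure as the model:
grad = (w = [1.0, 1.0], b = [1.0])

opt, model = Optimisers.update!(opt, model, grad)

model.w  # moved by -0.1 .* grad.w
model.b  # untouched, because its Leaf is frozen

Optimisers.thaw!(opt)  # later updates will affect `b` again
```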

## Adjusting Hyperparameters

To change the learning rate during training, use `adjust!`. This works much like `freeze!` by mutating the state tree, or part of it, without discarding the momenta. For the Flux model from just above:

```
Optimisers.adjust!(opt, 0.03) # change η for the whole model...
Optimisers.adjust!(opt.layers[3], 0.04) # ... or just for one layer.
```

To change other fields of the optimisation rule, it accepts keyword arguments:

```
Momentum |> fieldnames # (:eta, :rho)
Optimisers.adjust!(opt, rho = 0.95) # change ρ for the whole model.
```
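
One common use is a learning-rate schedule. The sketch below halves the learning rate each epoch; the schedule, the helper name `train_with_decay`, the toy `NamedTuple` model and the hand-written gradient of `sum(abs2, w)` are all assumptions made for illustration, not part of Optimisers.jl.

```julia
using Optimisers

# Hypothetical helper: one update per "epoch", halving eta each time.
function train_with_decay(opt, model; epochs = 3, eta0 = 0.1)
    for epoch in 1:epochs
        eta = eta0 * 0.5^(epoch - 1)       # 0.1, 0.05, 0.025, ...
        Optimisers.adjust!(opt, eta)       # mutates the rule in every Leaf, keeps momenta
        grad = (w = 2 .* model.w,)         # hand-written gradient of sum(abs2, w)
        opt, model = Optimisers.update!(opt, model, grad)
    end
    return opt, model
end

model = (w = [1.0, -2.0],)                 # toy model, illustrative values
opt = Optimisers.setup(Optimisers.Momentum(), model)

opt, model = train_with_decay(opt, model)
opt.w.rule.eta                             # the learning rate set in the last epoch
```

Because `adjust!` only replaces the rule stored in each `Leaf`, the velocity accumulated by `Momentum` in earlier epochs is kept when the learning rate changes.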

## Tied Parameters

If the same array appears twice (or more) in the model, Functors.jl should recognise this. Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters. Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained.

```
using Flux, Optimisers
enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10));
dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh));
model = Chain(; enc, dec)
st = Optimisers.setup(Optimisers.Adam(), model);
st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true
```

This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s. It will not at present work for `reshape`d arrays, nor for immutable arrays such as those from StaticArrays.jl.

## Obtaining a flat parameter vector

Instead of a nested tree-like structure, sometimes it is convenient to have all the parameters as one simple vector. Optimisers.jl contains a function `destructure` which creates this vector, and also creates a way to re-build the original structure with new parameters. Both flattening and re-building may be used within `gradient` calls.

An example with Flux's `model`:

```
using ForwardDiff  # an example of a package which only likes one array

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
    Conv((3, 3), 3 => 5, pad=1, bias=false),
    BatchNorm(5, relu),
    Conv((3, 3), 5 => 3, stride=16),
)
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

flat, re = destructure(model)
st = Optimisers.setup(rule, flat)  # state is just one Leaf now

∇flat = ForwardDiff.gradient(flat) do v
    m = re(v)  # rebuild a new object like model
    sum(m(image))  # call that as before
end

st, flat = Optimisers.update(st, flat, ∇flat)
@show sum(re(flat)(image));
```

Here `flat` contains only the 283 trainable parameters, while the non-trainable ones are preserved inside `re`, an object of type `Restructure`. When defining new layers, these can be specified if necessary by overloading `trainable`. By default, all numeric arrays visible to Functors.jl are assumed to contain trainable parameters. Tied parameters (arrays appearing in different layers) are included only once in `flat`.
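
To see the round trip in isolation, here is a small sketch on a plain `NamedTuple` rather than a Flux model. The structure and values are invented for illustration; the only Optimisers.jl function used is `destructure`.

```julia
using Optimisers

# A toy parameter tree (illustrative values only):
model = (w = [1.0 2.0; 3.0 4.0], b = [5.0, 6.0])

flat, re = Optimisers.destructure(model)
flat                      # all six numbers in one Vector; `w` is flattened column by column

# Rebuild the same structure from a modified vector:
model2 = re(2 .* flat)
model2.w                  # a 2×2 matrix again
model2.b                  # a 2-element vector again
```

The object `re` remembers the shapes and positions of the arrays, so any vector of the right length can be turned back into a tree shaped like `model`.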

Lux stores only the trainable parameters in `params`. This can also be flattened to a plain `Vector` in the same way:

```
using Random  # for default_rng

params, lux_state = Lux.setup(Random.default_rng(), lux_model);

flat, re = destructure(params)

∇flat = ForwardDiff.gradient(flat) do v
    p = re(v)  # rebuild an object like params
    y, _ = Lux.apply(lux_model, images, p, lux_state)
    sum(y)
end
```

## Collecting all trainable parameters

Sometimes it is useful to collect all trainable parameters in a model, similarly to what `destructure` does but without concatenating the arrays into a flat vector. This is done by `trainables`, which returns a list of arrays:

```
julia> using Flux, Optimisers
julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));
julia> trainables(model)
6-element Vector{AbstractArray}:
Float32[0.5756773 -0.1975264; 0.4723181 -0.7546912; -0.91631395 0.07392061]
Float32[0.0, 0.0, 0.0]
Float32[0.0, 0.0, 0.0]
Float32[1.0, 1.0, 1.0]
Float32[-0.8764882 0.40812716 0.1919528; -0.9123545 -0.4462516 0.6751252]
Float32[0.0, 0.0]
julia> l2reg(model) = sum([sum(abs2,p) for p in trainables(model)]);
julia> g = gradient(l2reg, model)[1];
```

Notice that the `BatchNorm` layer has two trainable parameters, `γ` and `β`, which are included in the list, while the `μ` and `σ²` buffers are not.