# Training a Flux Model

Training refers to the process of slowly adjusting the parameters of a model to make it work better. Besides the model itself, we will need three things:

- An
*objective function*that evaluates how well a model is doing on some input. - An
*optimisation rule*which describes how the model's parameters should be adjusted. - Some
*training data*to use as the input during this process.

Usually the training data is some collection of examples (or batches of examples) which are handled one-by-one. One *epoch* of training means that each example is used once, something like this:

```
# Initialise the optimiser for this model:
opt_state = Flux.setup(rule, model)
for data in train_set
# Unpack this element (for supervised training):
input, label = data
# Calculate the gradient of the objective
# with respect to the parameters within the model:
grads = Flux.gradient(model) do m
result = m(input)
loss(result, label)
end
# Update the parameters so as to reduce the objective,
# according the chosen optimisation rule:
Flux.update!(opt_state, model, grads[1])
end
```

This loop can also be written using the function `train!`

, but it's helpful to understand the pieces first:

```
train!(model, train_set, opt_state) do m, x, y
loss(m(x), y)
end
```

## Model Gradients

Fist recall from the section on taking gradients that `Flux.gradient(f, a, b)`

always calls `f(a, b)`

, and returns a tuple `(∂f_∂a, ∂f_∂b)`

. In the code above, the function `f`

passed to `gradient`

is an anonymous function with one argument, created by the `do`

block, hence `grads`

is a tuple with one element. Instead of a `do`

block, we could have written:

`grads = Flux.gradient(m -> loss(m(input), label), model)`

Since the model is some nested set of layers, `grads[1]`

is a similarly nested set of `NamedTuple`

s, ultimately containing gradient components. If (for example) `θ = model.layers[1].weight[2,3]`

is one scalar parameter, an entry in a matrix of weights, then the derivative of the loss with respect to it is `∂f_∂θ = grads[1].layers[1].weight[2,3]`

.

It is important that the execution of the model takes place inside the call to `gradient`

, in order for the influence of the model's parameters to be observed by Zygote.

It is also important that every `update!`

step receives a newly computed gradient, as it will change whenever the model's parameters are changed, and for each new data point.

## Loss Functions

The objective function must return a number representing how far the model is from the desired result. This is termed the *loss* of the model.

This number can be produced by any ordinary Julia code, but this must be executed within the call to `gradient`

. For instance, we could define a function

`loss(y_hat, y) = sum((y_hat .- y).^2)`

or write this directly inside the `do`

block above. Many commonly used functions, like `mse`

for mean-squared error or `crossentropy`

for cross-entropy loss, are available from the `Flux.Losses`

module.

## Optimisation Rules

The simplest kind of optimisation using the gradient is termed *gradient descent* (or sometimes *stochastic gradient descent* when, as here, it is not applied to the entire dataset at once).

Gradient descent needs a *learning rate* which is a small number describing how fast to walk downhill, usually written as the Greek letter "eta", `η`

. This is often described as a *hyperparameter*, to distinguish it from the parameters which are being updated `θ = θ - η * ∂loss_∂θ`

. We want to update all the parameters in the model, like this:

```
η = 0.01 # learning rate
# For each parameter array, update
# according to the corresponding gradient:
fmap(model, grads[1]) do p, g
p .= p .- η .* g
end
```

A slightly more refined version of this loop to update all the parameters is wrapped up as a function `update!`

`(opt_state, model, grads[1])`

. And the learning rate is the only thing stored in the `Descent`

struct.

However, there are many other optimisation rules, which adjust the step size and direction in various clever ways. Most require some memory of the gradients from earlier steps, rather than always walking straight downhill – `Momentum`

is the simplest. The function `setup`

creates the necessary storage for this, for a particular model. It should be called once, before training, and returns a tree-like object which is the first argument of `update!`

. Like this:

```
# Initialise momentum
opt_state = Flux.setup(Momentum(0.01, 0.9), model)
for data in train_set
grads = [...]
# Update both model parameters and optimiser state:
Flux.update!(opt_state, model, grads[1])
end
```

Many commonly-used optimisation rules, such as `Adam`

, are built-in. These are listed on the optimisers page.

This `setup`

makes another tree-like structure. Old versions of Flux did not do this, and instead stored a dictionary-like structure within the optimiser `Adam(0.001)`

. This was initialised on first use of the version of `update!`

for "implicit" parameters.

## Datasets & Batches

The loop above iterates through `train_set`

, expecting at each step a tuple `(input, label)`

. The very simplest such object is a vector of tuples, such as this:

```
x = randn(28, 28)
y = rand(10)
data = [(x, y)]
```

or `data = [(x, y), (x, y), (x, y)]`

for the same values three times.

Very often, the initial data is large arrays which you need to slice into examples. To produce one iterator of pairs `(x, y)`

, you might want `zip`

:

```
X = rand(28, 28, 60_000); # many images, each 28 × 28
Y = rand(10, 60_000)
data = zip(eachslice(X; dims=3), eachcol(Y))
first(data) isa Tuple{AbstractMatrix, AbstractVector} # true
```

Here each iteration will use one matrix `x`

(an image, perhaps) and one vector `y`

. It is very common to instead train on *batches* of such inputs (or *mini-batches*, the two words mean the same thing) both for efficiency and for better results. This can be easily done using the `DataLoader`

:

```
data = Flux.DataLoader((X, Y), batchsize=32)
x1, y1 = first(data)
size(x1) == (28, 28, 32)
length(data) == 1875 === 60_000 ÷ 32
```

Flux's layers are set up to accept such a batch of input data, and the convolutional layers such as `Conv`

require it. The batch index is always the last dimension.

## Training Loops

Simple training loops like the one above can be written compactly using the `train!`

function. Including `setup`

, this reads:

```
opt_state = Flux.setup(Adam(), model)
for epoch in 1:100
Flux.train!(model, train_set, opt_state) do m, x, y
loss(m(x), y)
end
end
```

Or explicitly writing the anonymous function which this `do`

block creates, `train!((m,x,y) -> loss(m(x),y), model, train_set, opt_state)`

is exactly equivalent.

Real training loops often need more flexibility, and the best way to do this is just to write the loop. This is ordinary Julia code, without any need to work through some callback API. Here is an example, in which it may be helpful to note:

- The function
`withgradient`

is like`gradient`

but also returns the value of the function, for logging or diagnostic use. - Logging or printing is best done outside of the
`gradient`

call, as there is no need to differentiate these commands. - To use
`result`

for logging purposes, you could change the`do`

block to end with`return my_loss(result, label), result`

, i.e. make the function passed to`withgradient`

return a tuple. The first element is always the loss. - Julia's
`break`

and`continue`

keywords let you exit from parts of the loop.

```
opt_state = Flux.setup(Adam(), model)
my_log = []
for epoch in 1:100
losses = Float32[]
for (i, data) in enumerate(train_set)
input, label = data
val, grads = Flux.withgradient(model) do m
# Any code inside here is differentiated.
# Evaluation of the model and loss must be inside!
result = m(input)
my_loss(result, label)
end
# Save the loss from the forward pass. (Done outside of gradient.)
push!(losses, val)
# Detect loss of Inf or NaN. Print a warning, and then skip update!
if !isfinite(val)
@warn "loss is $val on item $i" epoch
continue
end
Flux.update!(opt_state, model, grads[1])
end
# Compute some accuracy, and save details as a NamedTuple
acc = my_accuracy(model, train_set)
push!(my_log, (; acc, losses))
# Stop training when some criterion is reached
if acc > 0.95
println("stopping after $epoch epochs")
break
end
end
```

## Regularisation

The term *regularisation* covers a wide variety of techniques aiming to improve the result of training. This is often done to avoid overfitting.

Some of these can be implemented by simply modifying the loss function. *L₂ regularisation* (sometimes called ridge regression) adds to the loss a penalty proportional to `θ^2`

for every scalar parameter. A very simple model could be implemented as follows:

```
grads = Flux.gradient(densemodel) do m
result = m(input)
penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
my_loss(result, label) + 0.42f0 * penalty
end
```

Accessing each individual parameter array by hand won't work well for large models. Instead, we can use `Flux.trainables`

to collect all of them, and then apply a function to each one, and sum the result:

```
pen_l2(x::AbstractArray) = sum(abs2, x)/2
grads = Flux.gradient(model) do m
result = m(input)
penalty = sum(pen_l2, Flux.trainables(m))
my_loss(result, label) + 0.42f0 * penalty
end
```

However, the gradient of this penalty term is very simple: It is proportional to the original weights. So there is a simpler way to implement exactly the same thing, by modifying the optimiser instead of the loss function. This is done by replacing this:

`opt_state = Flux.setup(Adam(0.1), model)`

with this:

`decay_opt_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)`

Flux's optimisers are really modifications applied to the gradient before using it to update the parameters, and `OptimiserChain`

applies two such modifications. The first, `WeightDecay`

adds `0.42`

times the original parameter to the gradient, matching the gradient of the penalty above (with the same, unrealistically large, constant). After that, in either case, `Adam`

computes the final update.

The same trick works for *L₁ regularisation* (also called Lasso), where the penalty is `pen_l1(x::AbstractArray) = sum(abs, x)`

instead. This is implemented by `SignDecay(0.42)`

.

The same `OptimiserChain`

mechanism can be used for other purposes, such as gradient clipping with `ClipGrad`

or `ClipNorm`

.

Besides L1 / L2 / weight decay, another common and quite different kind of regularisation is provided by the `Dropout`

layer. This turns off some outputs of the previous layer during training. It should switch automatically, but see `trainmode!`

/ `testmode!`

to manually enable or disable this layer.

## Learning Rate Schedules

Finer control of training, you may wish to alter the learning rate mid-way through training. This can be done with `adjust!`

, like this:

```
opt_state = Flux.setup(Adam(0.1), model) # initialise once
for epoch in 1:1000
train!([...], state) # Train with η = 0.1 for first 100,
if epoch == 100 # then change to use η = 0.01 for the rest.
Flux.adjust!(opt_state, 0.01)
end
end
```

Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt_state, beta = (0.8, 0.99))`

. And such modifications can be applied to just one part of the model. For instance, this sets a different learning rate for the encoder and the decoder:

```
# Consider some model with two parts:
bimodel = Chain(enc = [...], dec = [...])
# This returns a tree whose structure matches the model:
opt_state = Flux.setup(Adam(0.02), bimodel)
# Adjust the learning rate to be used for bimodel.layers.enc
Flux.adjust!(opt_state.layers.enc, 0.03)
```

## Freezing layer parameters

To completely disable training of some part of the model, use `freeze!`

. This is a temporary modification, reversed by `thaw!`

:

```
Flux.freeze!(opt_state.layers.enc)
# Now training won't update parameters in bimodel.layers.enc
train!(loss, bimodel, data, opt_state)
# Un-freeze the entire model:
Flux.thaw!(opt_state)
```

While `adjust!`

and `freeze!`

/`thaw!`

make temporary modifications to the optimiser state, permanently removing some fields of a new layer type from training is usually done when defining the layer, by calling for example `@layer`

`NewLayer trainable=(weight,)`

.