To actually train a model we need four things:

  • An objective function that evaluates how well a model is doing given some input data.
  • The trainable parameters of the model.
  • A collection of data points that will be provided to the objective function.
  • An optimiser that will update the model parameters appropriately.
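Here is a minimal sketch of how these four pieces map onto Flux code. It assumes Flux 0.13 with its implicit-parameter API, and the tiny model and random data are made up for illustration:

```julia
using Flux

model = Dense(2 => 1)                        # a tiny model to train

loss(x, y) = Flux.Losses.mse(model(x), y)    # 1. objective function
ps = Flux.params(model)                      # 2. trainable parameters (weight and bias)
data = [(rand(Float32, 2, 16), rand(Float32, 1, 16))]  # 3. data: one (input, target) batch
opt = Descent(0.1)                           # 4. optimiser: plain gradient descent

Flux.train!(loss, ps, data, opt)             # one pass over `data`
```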

Training a model is typically an iterative process: we go over the data set, calculate the objective function over the data points, and use the optimiser to update the parameters so as to minimise it. This can be visualised as a simple loop:

for d in datapoints

  # `d` should produce a collection of arguments
  # to the loss function

  # Calculate the gradients of the loss function
  # with respect to the parameters
  grads = Flux.gradient(parameters) do
    loss(d...)
  end

  # Update the parameters based on the chosen
  # optimiser (opt)
  Flux.Optimise.update!(opt, parameters, grads)
end
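To make the shape of that loop concrete without any Flux machinery, here is a sketch in plain Julia that fits a single parameter `w` in `y = w * x`. The derivative is written by hand where Flux would call `gradient`, and the update is the rule a plain gradient-descent optimiser applies. The data, learning rate and function names (`dloss_dw`, `fit`) are illustrative assumptions:

```julia
# Toy data generated from y = 2x, so the ideal parameter is w = 2.
xs = [1.0, 2.0, 3.0, 4.0]
ys = 2.0 .* xs
data = zip(xs, ys)

loss(w, x, y) = (w * x - y)^2               # objective for one data point
dloss_dw(w, x, y) = 2 * (w * x - y) * x     # its derivative with respect to w

function fit(data; η = 0.01, epochs = 100)
  w = 0.0                                   # the single trainable parameter
  for epoch in 1:epochs, (x, y) in data
    w -= η * dloss_dw(w, x, y)              # gradient-descent update
  end
  return w
end

w = fit(data)
sum(loss(w, x, y) for (x, y) in data)       # close to zero after training
```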

To make this easier, Flux defines train!:

train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])

Uses a loss function and training data to improve the model's parameters according to a particular optimisation rule opt.

For each d in data, first the gradient of the loss is computed like this:

    gradient(() -> loss(d...), pars)  # if d isa Tuple
    gradient(() -> loss(d), pars)     # otherwise

Here pars is produced by calling Flux.params on your model. (Or just on the layers you want to train, like train!(loss, params(model[1:end-2]), data, opt).) This is the "implicit" style of parameter handling.

The gradient is then used by the optimiser opt to update the parameters:

    update!(opt, pars, grads)
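The Tuple rule above matters when a batch is not a Tuple. In this sketch (Flux 0.13 implicit style, made-up model and data), each element of `data` is a NamedTuple, so train! calls `loss(d)` with the whole batch instead of splatting it:

```julia
using Flux

m = Dense(4 => 2)
ps = Flux.params(m)

# A NamedTuple is not a Tuple, so train! passes it to `loss` as one argument.
loss(batch) = Flux.Losses.mse(m(batch.x), batch.y)
batch = (x = rand(Float32, 4, 8), y = rand(Float32, 2, 8))
data = fill(batch, 10)                 # the same batch ten times

before = loss(batch)
Flux.train!(loss, ps, data, Descent(0.1))
loss(batch) < before                   # true: ten small gradient steps lower the loss
```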

The optimiser should be from the Flux.Optimise module (see Optimisers). Different optimisers can be combined using Flux.Optimise.Optimiser.
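As a sketch of combining rules (the decay strength 1e-4 and learning rate 0.1 are arbitrary choices, and the model and data are made up), `Optimiser` passes the gradient through each rule in order: `WeightDecay` adds `1e-4 * p` to the gradient of each parameter `p`, then `Descent` scales the result by 0.1:

```julia
using Flux

opt = Flux.Optimise.Optimiser(Flux.Optimise.WeightDecay(1e-4), Flux.Optimise.Descent(0.1))

m = Dense(3 => 1)
ps = Flux.params(m)
loss(x, y) = Flux.Losses.mse(m(x), y)
Flux.train!(loss, ps, [(rand(Float32, 3, 4), rand(Float32, 1, 4))], opt)
```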

This training loop iterates through data once. You can use @epochs to do this several times, or build a longer data iterator, for instance with IterTools.ncycle(data, n) or Iterators.flatten(Iterators.repeated(data, n)).
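A longer data iterator can be built from Base Julia alone by combining `Iterators.repeated` with `Iterators.flatten`. A sketch with two made-up (x, y) pairs:

```julia
# Assumed toy data: two (x, y) pairs.
data = [(1.0, 2.0), (3.0, 4.0)]

# Iterators.repeated(data, 3) yields `data` itself three times;
# Iterators.flatten then walks through each copy in turn.
three_epochs = Iterators.flatten(Iterators.repeated(data, 3))

collect(three_epochs)   # the two pairs, three times over (6 elements)
```

Passing `three_epochs` to train! in place of `data` would then run three passes in a single call.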


Callbacks are given with the keyword argument cb. For example, this will print "training" at most once every 10 seconds (using Flux.throttle):

    train!(loss, pars, data, opt, cb = throttle(() -> println("training"), 10))

The callback can call Flux.stop to interrupt the training loop.

Multiple callbacks can be passed to cb as an array.
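A sketch of passing two callbacks at once, assuming Flux 0.13; the model, data, and the callback names `count_cb` and `log_loss` are made up for illustration. The first callback runs after every batch, while the throttled one fires at most once every 5 seconds:

```julia
using Flux

m = Dense(4 => 1)
ps = Flux.params(m)
loss(x, y) = Flux.Losses.mse(m(x), y)
data = [(rand(Float32, 4, 16), rand(Float32, 1, 16)) for _ in 1:100]
opt = Descent(0.1)

batches_seen = Ref(0)
count_cb() = (batches_seen[] += 1)       # runs after every batch
log_loss() = @info "training" current_loss = loss(data[1]...)
log_cb = Flux.throttle(log_loss, 5)     # runs at most once every 5 seconds

Flux.train!(loss, ps, data, opt, cb = [count_cb, log_cb])
batches_seen[]                           # 100: one call per batch
```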


There are plenty of examples in the model zoo, and more information can be found in Custom Training Loops.

Loss Functions

The objective function must return a number representing how far the model is from its target: the loss of the model. The loss function that we defined in basics will work as an objective. In addition to custom losses, a model can be trained with the commonly used losses grouped under the Flux.Losses module. We can also define an objective in terms of some model:

using Flux

m = Chain(
  Dense(784 => 32, σ),
  Dense(32 => 10), softmax)

loss(x, y) = Flux.Losses.mse(m(x), y)
ps = Flux.params(m)
opt = Descent(0.1)  # any optimiser from Flux.Optimise

# later, once `data` has been defined
Flux.train!(loss, ps, data, opt)

The objective will almost always be defined in terms of some cost function that measures the distance of the prediction m(x) from the target y. Flux has several of these built-in, like mse for mean squared error or crossentropy for cross-entropy loss, but you can calculate it however you want. For a list of all built-in loss functions, check out the losses reference.
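For classification, the built-in cross-entropy is usually paired with one-hot targets. Here is a sketch with five made-up inputs and labels, reusing the model shape from above:

```julia
using Flux

m = Chain(Dense(784 => 32, σ), Dense(32 => 10), softmax)

x = rand(Float32, 784, 5)          # five fake inputs
labels = [0, 3, 9, 1, 4]           # their (made-up) class labels
y = Flux.onehotbatch(labels, 0:9)  # 10×5 one-hot target matrix

loss(x, y) = Flux.Losses.crossentropy(m(x), y)
loss(x, y)                         # a positive scalar
```

Because this model ends in softmax, crossentropy is applied to probabilities. For a model without the final softmax, Flux.Losses.logitcrossentropy applied to the raw outputs is the numerically more stable choice.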

At first glance, it may seem strange that the model we want to train is not one of the input arguments of Flux.train!. However, the target of the optimiser is not the model itself but the objective function, which measures the discrepancy between modelled and observed data. In other words, the model is implicitly defined in the objective function, so there is no need to pass it explicitly. Passing the objective function, rather than a model and a cost function separately, provides more flexibility and the possibility of optimising the calculations.
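One payoff of this flexibility is that extra terms can be folded into the objective without changing train!. The following sketch adds an L2 penalty on the weights; the penalty strength 0.01, the model, and the data are arbitrary assumptions:

```julia
using Flux

m = Dense(3 => 2)
ps = Flux.params(m)

# Objective = data fit + L2 penalty on the weights. The model appears only
# inside the objective; train! never needs to be told about it directly.
penalty() = sum(abs2, m.weight)
loss(x, y) = Flux.Losses.mse(m(x), y) + 0.01f0 * penalty()

x, y = rand(Float32, 3, 5), rand(Float32, 2, 5)
gs = gradient(() -> loss(x, y), ps)
size(gs[m.weight])   # (2, 3): the penalty contributes to the weight gradient
```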

Model parameters

The model to be trained must have a set of tracked parameters that are used to calculate the gradients of the objective function. In the basics section it is explained how to create models with such parameters. The second argument of the function Flux.train! must be an object containing those parameters, which can be obtained from a model m as Flux.params(m).

Such an object contains references to the model's parameters, not copies, so after training the model uses the updated values.

Handling all the parameters on a layer-by-layer basis is explained in the Layer Helpers section. For freezing model parameters, see the Advanced Usage Guide.
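To check that Flux.params holds references rather than copies, here is a sketch with a made-up two-layer model and one random batch:

```julia
using Flux

m = Chain(Dense(2 => 3), Dense(3 => 1))
ps = Flux.params(m)
first(ps) === m[1].weight                 # true: the very same array, not a copy

before = copy(m[1].weight)
loss(x, y) = Flux.Losses.mse(m(x), y)
Flux.train!(loss, ps, [(rand(Float32, 2, 4), rand(Float32, 1, 4))], Descent(0.1))
m[1].weight != before                     # true: training updated the model in place
```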


Given a model or specific layers from a model, create a Params object pointing to its trainable parameters.

This can be used with the gradient function, see Taking Gradients, or as input to the Flux.train! function.

The behaviour of params on custom types can be customized using Functors.@functor or Flux.trainable.


julia> using Flux: params

julia> params(Chain(Dense(ones(2,3)), softmax))  # unpacks Flux models
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])

julia> bn = BatchNorm(2, relu)
BatchNorm(2, relu)  # 4 parameters, plus 4 non-trainable

julia> params(bn)  # only the trainable parameters
Params([Float32[0.0, 0.0], Float32[1.0, 1.0]])

julia> params([1, 2, 3], [4])  # one or more arrays of numbers
Params([[1, 2, 3], [4]])

julia> params([[1, 2, 3], [4]])  # unpacks array of arrays
Params([[1, 2, 3], [4]])

julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin))  # ignores scalars, unpacks NamedTuples
Params([[2 2], [3, 3, 3]])
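A sketch of customising params for a user-defined type, assuming Flux 0.13. The `Affine` layer is hypothetical: `Flux.@functor` lets Flux look inside the struct, and overloading `Flux.trainable` restricts the trainable set to the weight matrix only:

```julia
using Flux

# Hypothetical layer computing W * x .+ b, where only W should be trained.
struct Affine
  W::Matrix{Float32}
  b::Vector{Float32}
end
Affine(nin::Integer, nout::Integer) = Affine(randn(Float32, nout, nin), zeros(Float32, nout))
(a::Affine)(x) = a.W * x .+ a.b

Flux.@functor Affine                     # let Flux look inside the struct
Flux.trainable(a::Affine) = (W = a.W,)   # expose only W as trainable

a = Affine(3, 2)
ps = Flux.params(a)
length(ps)                               # 1: the bias b is left out
```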


The data argument of train! provides a collection of data to train with (usually a set of inputs x and target outputs y). For example, here's a dummy dataset with only one data point:

x = rand(784)
y = rand(10)
data = [(x, y)]

Flux.train! will call loss(x, y), calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:

data = [(x, y), (x, y), (x, y)]
# Or equivalently
using IterTools: ncycle
data = ncycle([(x, y)], 3)

It's common to load the xs and ys separately. In that case you can pair them up with zip:

xs = [rand(784), rand(784), rand(784)]
ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)

Training data can be conveniently partitioned for mini-batch training using the Flux.Data.DataLoader type:

using Flux.Data: DataLoader

X = rand(28, 28, 60000)
Y = rand(0:9, 60000)
data = DataLoader((X, Y), batchsize=128)
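To see how the batches come out, here is a sketch that uses 1,000 fake samples instead of 60,000 to keep it light, and turns on `shuffle` so the order changes every pass. DataLoader splits along the last dimension, and the final batch is allowed to be smaller:

```julia
using Flux.Data: DataLoader

X = rand(Float32, 28, 28, 1000)   # 1,000 fake 28×28 "images"; the last dimension indexes samples
Y = rand(0:9, 1000)               # one fake label per image
data = DataLoader((X, Y), batchsize=128, shuffle=true)

length(data)                      # 8: seven full batches of 128 and a final batch of 104
x1, y1 = first(data)
size(x1), size(y1)                # ((28, 28, 128), (128,))
```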

Note that, by default, train! only loops over the data once (a single "epoch"). A convenient way to run multiple epochs from the REPL is provided by @epochs.

julia> using Flux: @epochs

julia> @epochs 2 println("hello")
[ Info: Epoch 1
[ Info: Epoch 2

julia> @epochs 2 Flux.train!(...)
# Train for two epochs
@epochs N body

Run body N times. Mainly useful for quickly doing multiple epochs of training in a REPL.


The macro @epochs will be removed in Flux 0.14. Please just write an ordinary for loop.


julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
[ Info: Epoch 2
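An ordinary loop does the same job as @epochs. Here is a sketch with a made-up model and data that also logs the mean loss after each epoch:

```julia
using Flux

m = Dense(4 => 1)
ps = Flux.params(m)
loss(x, y) = Flux.Losses.mse(m(x), y)
data = [(rand(Float32, 4, 8), rand(Float32, 1, 8)) for _ in 1:10]
opt = Descent(0.1)

for epoch in 1:2
  Flux.train!(loss, ps, data, opt)     # one full pass over `data`
  mean_loss = sum(loss(d...) for d in data) / length(data)
  @info "Epoch $epoch" mean_loss       # plays the role of the "[ Info: Epoch N" log
end
```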


train! takes an additional argument, cb, that's used for callbacks so that you can observe the training process. For example:

train!(objective, ps, data, opt, cb = () -> println("training"))

Callbacks are called for every batch of training data. You can slow this down using Flux.throttle(f, timeout) which prevents f from being called more than once every timeout seconds.

A more typical callback might look like this:

# test_x, test_y: a single batch of test data, created beforehand
evalcb() = @show(loss(test_x, test_y))
throttled_cb = Flux.throttle(evalcb, 5)
for epoch in 1:20
  Flux.train!(objective, ps, data, opt, cb = throttled_cb)
end

Calling Flux.stop() in a callback will exit the training loop early.

# `accuracy` is a user-defined function returning the current accuracy
cb = function ()
  accuracy() > 0.9 && Flux.stop()
end

Custom Training Loops

The Flux.train! function can be very convenient, especially for simple problems. For some problems, however, it's much cleaner to write your own custom training loop. The following example works similarly to Flux.train! but without callbacks. You don't need callbacks if you call your own functions directly inside the loop, e.g. in the places marked with comments.

using Flux

function my_custom_train!(loss, ps, data, opt)
  # training_loss is declared local so it will be available for logging outside the gradient calculation.
  local training_loss
  ps = Flux.Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Code inserted here will be differentiated; unless you need that gradient information,
      # it is better to do the work outside this block.
      return training_loss
    end
    # Insert whatever code you want here that needs training_loss, e.g. logging.
    # logging_callback(training_loss)
    # Insert whatever code you want here that needs gradients,
    # e.g. logging histograms with TensorBoardLogger.jl to check for exploding gradients.
    Flux.Optimise.update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping.
  end
end

You could simplify this further, for example by hard-coding in the loss function.
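As a sketch of such a simplification, here the MSE loss is hard-coded into the loop, so only the model, data and optimiser are passed in. The name `train_mse!`, the model, and the data are hypothetical:

```julia
using Flux

# Hypothetical simplified loop with the MSE loss hard-coded.
function train_mse!(m, data, opt)
  ps = Flux.params(m)
  for (x, y) in data
    gs = gradient(() -> Flux.Losses.mse(m(x), y), ps)
    Flux.Optimise.update!(opt, ps, gs)
  end
end

m = Dense(2 => 1)
x, y = rand(Float32, 2, 4), rand(Float32, 1, 4)
before = Flux.Losses.mse(m(x), y)
train_mse!(m, fill((x, y), 20), Descent(0.1))
Flux.Losses.mse(m(x), y) < before   # true: repeated steps on one batch lower the loss
```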

Another possibility is to use Zygote.pullback to access the training loss and the gradient simultaneously.

using Flux, Zygote

function my_custom_train!(loss, ps, data, opt)
  ps = Flux.Params(ps)
  for d in data
    # back is a function that computes the product of the gradient so far with its argument.
    train_loss, back = Zygote.pullback(() -> loss(d...), ps)
    # Insert whatever code you want here that needs train_loss, e.g. logging.
    # logging_callback(train_loss)
    # Apply back() to the correct type of 1.0 to get the gradient of loss.
    gs = back(one(train_loss))
    # Insert whatever code you want here that needs the gradient,
    # e.g. logging it with TensorBoardLogger.jl as a histogram so you can see if it is becoming huge.
    Flux.Optimise.update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping.
  end
end
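To see what `back` does in isolation, here is a scalar sketch, assuming Zygote is installed in the active environment. The pullback of `x^2` at `x = 3` returns the value together with a function that multiplies the derivative `2x` by whatever seed it is given. That is why a training loop seeds it with `one(train_loss)`:

```julia
using Zygote

y, back = Zygote.pullback(x -> x^2, 3.0)
y           # 9.0
back(1.0)   # (6.0,): the derivative 2x at x = 3, one entry per argument
back(2.0)   # (12.0,): the same derivative scaled by the seed 2.0
```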