To actually train a model we need four things:
- An objective function that evaluates how well a model is doing given some input data.
- The trainable parameters of the model.
- A collection of data points that will be provided to the objective function.
- An optimiser that will update the model parameters appropriately.
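To make the four ingredients concrete, here is a tiny package-free sketch. It is only an illustration and not Flux's API: the names `SimpleDescent`, `step!`, `objective` and `grad` are hypothetical, the "model" is a two-element weight vector for a straight line, and the gradient is derived by hand where Flux would use backpropagation.

```julia
# Hypothetical plain gradient-descent optimiser (not Flux's own optimiser type).
struct SimpleDescent
    eta::Float64   # learning rate
end

# Update parameters in place: p <- p - eta * g
step!(opt::SimpleDescent, p::AbstractVector, g::AbstractVector) = (p .-= opt.eta .* g; p)

# 1. Trainable parameters: slope and intercept of y ≈ w[1]*x + w[2]
w = [0.0, 0.0]

# 2. Objective: squared error for one (x, y) pair
objective(x, y) = (w[1] * x + w[2] - y)^2

# Hand-derived gradient of the objective with respect to w
grad(x, y) = (r = w[1] * x + w[2] - y; [2r * x, 2r])

# 3. Data: noise-free points on the line y = 2x + 1
data = [(x, 2x + 1) for x in -1.0:0.25:1.0]

# 4. Optimiser
opt = SimpleDescent(0.1)

# Plain loop over the data, repeated for 200 passes
for epoch in 1:200, (x, y) in data
    step!(opt, w, grad(x, y))
end

w   # approaches [2.0, 1.0]
```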
With these we can call train!:

```julia
train!(loss, params, data, opt; cb)
```

For each datapoint d in data, train! computes the gradient of loss(d...) through backpropagation and calls the optimizer opt to update params.

If the datapoints d are numeric arrays, no splatting is assumed and the gradient of loss(d) is computed instead.

A callback is given with the keyword argument cb. For example, this will print "training" every 10 seconds (using Flux.throttle):

```julia
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
```

The callback can call Flux.stop to interrupt the training loop.

Multiple optimisers and callbacks can be passed to opt and cb as arrays.
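The splatting rule and the per-datapoint callback described above can be sketched without Flux. This is a hypothetical illustration, not Flux's implementation: `call_loss` and `sketch_train!` are made-up names, and `apply_step!` stands in for the gradient computation and optimiser update that train! performs.

```julia
# Numeric arrays are passed to the loss whole; anything else (e.g. a tuple) is splatted.
call_loss(loss, d::AbstractArray{<:Number}) = loss(d)
call_loss(loss, d) = loss(d...)

# Hypothetical sketch of the train! loop: one loss evaluation, one (stand-in)
# update and one callback call per datapoint. Returns the losses it saw.
function sketch_train!(loss, data, apply_step!; cb = () -> nothing)
    losses = Float64[]
    for d in data
        push!(losses, call_loss(loss, d))  # objective on this datapoint
        apply_step!(d)                      # stand-in for backprop + optimiser step
        cb()                                # callback runs once per datapoint
    end
    return losses
end

calls = Ref(0)
losses = sketch_train!((x, y) -> (x - y)^2, [(1.0, 2.0), (3.0, 3.0)], d -> nothing;
                       cb = () -> (calls[] += 1))
# losses == [1.0, 0.0], calls[] == 2
```

A tuple datapoint such as (1.0, 2.0) becomes loss(1.0, 2.0), while a plain vector would be handed over as a single argument.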
There are plenty of examples in the model zoo.
The objective function must return a number representing how far the model is from its target: the loss of the model. The loss function that we defined in basics will work as an objective. We can also define an objective in terms of some model:
```julia
using Flux

m = Chain(
    Dense(784, 32, σ),
    Dense(32, 10), softmax)

loss(x, y) = Flux.mse(m(x), y)
ps = Flux.params(m)

# later
Flux.train!(loss, ps, data, opt)
```
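To show what this model computes, the sketch below writes the same shape of forward pass in plain Julia. The helpers `sigmoid`, `dense_layer` and `softmax_vec` are hypothetical stand-ins for σ, Dense and softmax, and the random initialisation is an assumption chosen only so the example runs; Flux's layers initialise and compute things in their own way.

```julia
# Hypothetical stand-ins for σ, Dense and softmax (not Flux's code).
sigmoid(z) = 1 / (1 + exp(-z))

# A dense layer computes act.(W*x .+ b)
dense_layer(W, b, act = identity) = x -> act.(W * x .+ b)

# Numerically stable softmax: subtract the maximum before exponentiating
function softmax_vec(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

W1, b1 = randn(32, 784) .* 0.01, zeros(32)
W2, b2 = randn(10, 32) .* 0.1, zeros(10)

layer1 = dense_layer(W1, b1, sigmoid)   # 784 -> 32 with a sigmoid
layer2 = dense_layer(W2, b2)            # 32 -> 10, no activation
model(x) = softmax_vec(layer2(layer1(x)))

ŷ = model(rand(784))   # 10 non-negative values that sum to 1
```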
The objective will almost always be defined in terms of some cost function that measures the distance of the prediction m(x) from the target y. Flux has several of these built in, like mse for mean squared error or crossentropy for cross-entropy loss, but you can calculate it however you want. For a list of all built-in loss functions, check out the layer reference.
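The textbook definitions behind these two cost functions are short enough to write by hand. The names `my_mse` and `my_crossentropy` are hypothetical; Flux's own mse and crossentropy may differ in details such as aggregation options and numerical safeguards.

```julia
# Mean squared error: average of the squared differences
my_mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

# Cross entropy between a predicted probability vector ŷ and a target y
my_crossentropy(ŷ, y) = -sum(y .* log.(ŷ))

a = my_mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])       # (0 + 0 + 4) / 3
b = my_crossentropy([0.25, 0.75], [0.0, 1.0])      # -log(0.75)
```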
At first glance it may seem strange that the model we want to train is not also one of the arguments of Flux.train!. However, the target of the optimizer is not the model itself but the objective function, which represents the departure between modelled and observed data. In other words, the model is implicitly defined in the objective function, so there is no need to pass it explicitly. Passing the objective function, rather than the model and a cost function separately, provides more flexibility and leaves room for optimising the calculations.
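The idea that "the model is implicitly defined in the objective function" is ordinary closure behaviour in Julia. In this package-free sketch (all names are illustrative), the objective reads the parameter array W through the model, so mutating W in place, as an optimiser would, changes the objective without the model ever being passed around.

```julia
W = [1.0 0.0; 0.0 1.0]                      # trainable parameters (mutable array)
model(x) = W * x                            # the model reads W from the enclosing scope
objective(x, y) = sum(abs2, model(x) .- y)  # depends on the model only implicitly

x, y = [1.0, 2.0], [2.0, 4.0]
before = objective(x, y)    # (1-2)^2 + (2-4)^2 = 5.0
W .= [2.0 0.0; 0.0 2.0]     # an optimiser would mutate W in place like this
after = objective(x, y)     # model(x) now equals y, so the loss is 0.0
```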
The model to be trained must have a set of tracked parameters that are used to calculate the gradients of the objective function. The basics section explains how to create models with such parameters. The second argument of Flux.train! must be an object containing those parameters, which can be obtained from a model m as Flux.params(m).

Such an object holds references to the model's parameters, not copies, so after training the model behaves according to the updated values.
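The "reference, not a copy" behaviour is plain Julia array semantics. In this sketch an ordinary vector stands in for the parameter collection (an assumption for illustration; Flux uses its own Params type): updating through the collection changes the model's arrays, whereas a deep copy would not.

```julia
W = [1.0, 2.0]
b = [0.5]
ps = [W, b]                 # stand-in parameter collection: stores references

ps[1] .-= 0.5               # in-place update through the collection
W                           # now [0.5, 1.5]: the model's own array changed

ps_copy = deepcopy(ps)      # a real copy does NOT share storage
ps_copy[2] .= 0.0
b                           # still [0.5]
```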
Handling all the parameters on a layer by layer basis is explained in the Layer Helpers section. Also, for freezing model parameters, see the Advanced Usage Guide.
The data argument provides a collection of data to train with (usually a set of inputs x and target outputs y). For example, here's a dummy data set with only one data point:
```julia
x = rand(784)
y = rand(10)
data = [(x, y)]
```
Flux.train! will call loss(x, y), calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:
```julia
data = [(x, y), (x, y), (x, y)]
# Or equivalently
using IterTools: ncycle
data = ncycle([(x, y)], 3)
```
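If you would rather not depend on IterTools, Base's iterator utilities can express the same repetition. This swaps in Base.Iterators in place of ncycle; it is an alternative, not what the original text uses.

```julia
x, y = rand(784), rand(10)

# Repeat the one-element dataset 3 times and flatten into a single stream of (x, y) tuples
data = Iterators.flatten(Iterators.repeated([(x, y)], 3))

points = collect(data)   # three tuples, all referring to the same x and y arrays
```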
It's common to load the xs and ys separately. In this case you can use zip:
```julia
xs = [rand(784), rand(784), rand(784)]
ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)
```
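Each element produced by zip is an (x, y) tuple, which is exactly the shape that gets splatted into loss(x, y). The sketch below checks this in plain Julia; `toy_loss` is a hypothetical function used only to show the splatting call.

```julia
xs = [rand(784), rand(784), rand(784)]
ys = [rand(10), rand(10), rand(10)]
data = zip(xs, ys)

d = first(data)                        # (xs[1], ys[1])
toy_loss(x, y) = sum(x) - sum(y)       # hypothetical loss, just for the call
l = toy_loss(d...)                     # same as toy_loss(xs[1], ys[1])
```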
Training data can be conveniently partitioned for mini-batch training using the Flux.Data.DataLoader type:
```julia
using Flux.Data: DataLoader

X = rand(28, 28, 60000)
Y = rand(0:9, 60000)
data = DataLoader(X, Y, batchsize=128)
```
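The core idea of mini-batching along the last (observation) dimension can be sketched in a few lines. `simple_batches` is a hypothetical helper, not DataLoader's implementation: it never shuffles and keeps the final partial batch, whereas DataLoader also offers options such as shuffling. A smaller dataset of 1000 observations is assumed here to keep the example light.

```julia
# Split the last dimension of X (and the matching entries of Y) into consecutive batches.
function simple_batches(X::AbstractArray, Y::AbstractVector, batchsize::Integer)
    n = size(X, ndims(X))
    @assert n == length(Y) "X and Y must have the same number of observations"
    return [(selectdim(X, ndims(X), i:min(i + batchsize - 1, n)),   # a view, no copying
             Y[i:min(i + batchsize - 1, n)]) for i in 1:batchsize:n]
end

X = rand(28, 28, 1000)
Y = rand(0:9, 1000)
batches = simple_batches(X, Y, 128)

length(batches)            # 8 = ceil(1000 / 128)
size(last(batches)[1])     # (28, 28, 104): the final, partial batch
```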
Note that, by default, train! only loops over the data once (a single "epoch"). A convenient way to run multiple epochs from the REPL is provided by the @epochs macro:
```julia
julia> using Flux: @epochs

julia> @epochs 2 println("hello")
INFO: Epoch 1
hello
INFO: Epoch 2
hello

julia> @epochs 2 Flux.train!(...)
# Train for two epochs
```
@epochs N body runs body N times. It is mainly useful for quickly doing multiple epochs of training in a REPL:

```julia
julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
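The behaviour described for @epochs (run the body N times and log each epoch) is simple to mimic with a macro. `@my_epochs` below is a hypothetical sketch written from that description, not Flux's source; it escapes N and the body so they are evaluated in the caller's scope.

```julia
macro my_epochs(n, body)
    quote
        for epoch in 1:$(esc(n))
            @info string("Epoch ", epoch)   # log which pass we are on
            $(esc(body))                    # run the user's expression
        end
    end
end

seen = Int[]
@my_epochs 3 push!(seen, length(seen) + 1)
seen   # [1, 2, 3]
```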
train! takes an additional keyword argument, cb, that is used for callbacks so that you can observe the training process. For example:

```julia
train!(objective, ps, data, opt, cb = () -> println("training"))
```
Callbacks are called for every batch of training data. You can slow this down using Flux.throttle(f, timeout), which prevents f from being called more than once every timeout seconds.
A more typical callback might look like this:
```julia
test_x, test_y = # ... create single batch of test data ...
evalcb() = @show(loss(test_x, test_y))
Flux.train!(objective, ps, data, opt, cb = throttle(evalcb, 5))
```
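To see how throttling limits how often a callback runs, here is a simplified, package-free sketch. `simple_throttle` is hypothetical and only handles the leading call; Flux.throttle may differ, for example in how it treats calls that arrive during the waiting period.

```julia
# Return a wrapper that calls f at most once per `timeout` seconds.
function simple_throttle(f, timeout)
    last_call = Ref(-Inf)               # time of the last call that actually ran
    return function (args...)
        t = time()
        if t - last_call[] >= timeout
            last_call[] = t
            return f(args...)
        end
        return nothing                  # call suppressed
    end
end

ncalls = Ref(0)
cb = simple_throttle(() -> (ncalls[] += 1), 5)
for _ in 1:1000
    cb()            # all of these happen well within 5 seconds
end
ncalls[]            # 1
```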
Calling Flux.stop() in a callback will exit the training loop early. Here accuracy() stands for a metric you define yourself:

```julia
cb = function ()
    accuracy() > 0.9 && Flux.stop()
end
```
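One common way to implement this kind of early exit is to throw a dedicated exception from the callback and catch it in the loop. The sketch below is hypothetical (the names `StopTraining`, `request_stop` and `loop_with_callback` are made up) and only similar in spirit to Flux.stop; for convenience its callback receives the step number, unlike the zero-argument callbacks above.

```julia
struct StopTraining <: Exception end
request_stop() = throw(StopTraining())

function loop_with_callback(nsteps, cb)
    steps = 0
    try
        for _ in 1:nsteps
            steps += 1                    # a training step would happen here
            cb(steps)
        end
    catch e
        e isa StopTraining || rethrow()   # only swallow our own stop signal
    end
    return steps
end

# Stop once a made-up "accuracy" metric exceeds 0.9
fake_accuracy(step) = step / 10
n = loop_with_callback(100, step -> fake_accuracy(step) > 0.9 && request_stop())
# n == 10: at step 10 the accuracy is 1.0 > 0.9
```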
Custom Training loops
The Flux.train! function can be very convenient, especially for simple problems, and callbacks make it very flexible. But for some problems it is much cleaner to write your own custom training loop. An example follows that works similarly to the default Flux.train! but with no callbacks. You don't need callbacks if you code the calls to your functions directly into the loop, e.g. in the places marked with comments.
```julia
using Flux: Params, gradient
using Flux.Optimise: update!

function my_custom_train!(loss, ps, data, opt)
    # Declared local so the loss is also available outside the gradient block.
    local training_loss
    ps = Params(ps)
    for d in data
        gs = gradient(ps) do
            training_loss = loss(d...)
            # Code placed here is differentiated; do other work outside this block.
            return training_loss
        end
        # Insert whatever code you want here that needs training_loss, e.g. logging.
        # Insert whatever code you want here that needs the gradient,
        # e.g. logging with TensorBoardLogger.jl as a histogram to see if it is becoming huge.
        update!(opt, ps, gs)
        # Here you might like to check validation set accuracy, and break out to do early stopping.
    end
end
```
You could simplify this further, for example by hard-coding in the loss function.
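As an illustration of hard-coding the loss, here is a package-free sketch of a custom loop for a one-dimensional linear model. The function `train_line!`, its keywords and the stopping tolerance are hypothetical; the gradient is derived by hand where Flux would use gradient, and the early-stopping check uses the training loss rather than a validation set to keep the example short.

```julia
function train_line!(w, data; eta = 0.1, epochs = 200, tol = 1e-12)
    for epoch in 1:epochs
        total = 0.0
        for (x, y) in data
            r = w[1] * x + w[2] - y          # residual of the hard-coded model
            total += r^2                      # hard-coded squared-error loss
            w .-= eta .* (2r .* [x, 1.0])     # gradient step on [slope, intercept]
        end
        # Early stopping: leave the loop once the epoch's loss is negligible
        total < tol && return epoch
    end
    return epochs
end

w = [0.0, 0.0]
data = [(x, 3x - 2) for x in -1.0:0.5:1.0]
stopped_at = train_line!(w, data)
w   # close to [3.0, -2.0]
```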