To actually train a model we need three things:
- An objective function that evaluates how well a model is doing given some input data.
- A collection of data points that will be provided to the objective function.
- An optimiser that will update the model parameters appropriately.
With these we can call
Flux.train!(objective, data, opt)
There are plenty of examples in the model zoo.
The objective function must return a number representing how far the model is from its target – the loss of the model. The
loss function that we defined in basics will work as an objective. We can also define an objective in terms of some model:
m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10), softmax)

loss(x, y) = Flux.mse(m(x), y)

# later
Flux.train!(loss, data, opt)
The objective will almost always be defined in terms of some cost function that measures the distance of the prediction m(x) from the target y. Flux has several of these built in, like mse for mean squared error or crossentropy for cross entropy loss, but you can calculate it however you want.
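To make "calculate it however you want" concrete, here is a minimal sketch of two hand-written cost functions in plain Julia. The names mse_sketch and crossentropy_sketch are hypothetical, and the formulas are the usual textbook definitions for a single sample; they are not claimed to match Flux's built-in implementations exactly.

```julia
# Hypothetical hand-written cost functions (textbook definitions, single sample).
# They are illustrations only, not Flux's own mse / crossentropy.

# Mean squared error between prediction ŷ and target y.
mse_sketch(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

# Cross entropy between predicted probabilities ŷ and target distribution y.
crossentropy_sketch(ŷ, y) = -sum(y .* log.(ŷ))

println(mse_sketch([1.0, 2.0], [1.0, 4.0]))
println(crossentropy_sketch([0.5, 0.5], [1.0, 0.0]))
```

Any function of this shape, taking a prediction and a target and returning a single number, could be wrapped in an objective such as loss(x, y) = mse_sketch(m(x), y).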
The data argument provides a collection of data to train with (usually a set of inputs x and target outputs y). For example, here's a dummy data set with only one data point:
x = rand(784)
y = rand(10)
data = [(x, y)]
Flux.train! will call loss(x, y), calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:
data = [(x, y), (x, y), (x, y)]
# Or equivalently
data = Iterators.repeated((x, y), 3)
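The per-data-point cycle described above (evaluate the loss, compute gradients, update the weights, move on) can be illustrated without Flux. The following is a rough sketch only: the one-parameter model, the hand-derived gradient and the step! helper are all assumptions made for illustration and do not reflect how train! is implemented.

```julia
# Illustrative stand-in for the loop that train! performs; not Flux's code.
w = [0.0]                          # a single trainable parameter
predict(x) = w[1] * x
loss(x, y) = (predict(x) - y)^2

# One update: gradient of the loss with respect to w, derived by hand,
# followed by a plain gradient-descent step with learning rate η.
function step!(x, y; η = 0.1)
    grad = 2 * (predict(x) - y) * x
    w[1] -= η * grad
end

# The same data point repeated 50 times, as with Iterators.repeated above.
data = Iterators.repeated((1.0, 2.0), 50)

for (x, y) in data
    step!(x, y)
end

println("w = ", w[1], ", loss = ", loss(1.0, 2.0))
```

Each pass through the loop moves w closer to 2, the value that maps the input 1.0 onto the target 2.0.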
It's common to load the xs and ys separately. In this case you can use zip:
xs = [rand(784), rand(784), rand(784)]
ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)
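As a quick check of what such a zipped collection contains, the sketch below iterates over it in plain Julia. It only uses Base functions; the printing loop is there for illustration and is not part of the training API.

```julia
# Build inputs and targets separately, then pair them up with zip.
xs = [rand(784) for _ in 1:3]
ys = [rand(10) for _ in 1:3]
data = zip(xs, ys)

# Each element of data is an (x, y) tuple, in the order the inputs were given.
for (x, y) in data
    println(size(x), " -> ", size(y))
end
```

Because zip is lazy, no copies of xs or ys are made; the tuples are produced one at a time as the collection is iterated.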
Note that, by default, train! only loops over the data once (a single "epoch"). A convenient way to run multiple epochs from the REPL is provided by the @epochs macro:
julia> using Flux: @epochs

julia> @epochs 2 println("hello")
INFO: Epoch 1
hello
INFO: Epoch 2
hello

julia> @epochs 2 Flux.train!(...)
# Train for two epochs
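Outside the REPL, an ordinary loop gives the same effect. The sketch below assumes a hypothetical helper, run_epochs, and a zero-argument train_once function standing in for a call to train!; neither is part of Flux.

```julia
# Hypothetical helper: run one full pass over the data n times.
# train_once stands in for something like () -> Flux.train!(...).
function run_epochs(train_once, n)
    for epoch in 1:n
        println("Epoch ", epoch)
        train_once()
    end
end

# Count the passes instead of training, to keep the example self-contained.
passes = Ref(0)
run_epochs(() -> (passes[] += 1), 2)
println("completed passes: ", passes[])
```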
train! takes an additional argument, cb, that's used for callbacks so that you can observe the training process. For example:
train!(objective, data, opt, cb = () -> println("training"))
Callbacks are called for every batch of training data. You can slow this down using Flux.throttle(f, timeout), which prevents f from being called more than once every timeout seconds.
A more typical callback might look like this:
test_x, test_y = # ... create single batch of test data ...
evalcb() = @show(loss(test_x, test_y))
Flux.train!(objective, data, opt, cb = Flux.throttle(evalcb, 5))
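To show the idea behind throttling a callback, here is a simplified sketch in plain Julia. The function throttle_sketch is hypothetical and much smaller than Flux.throttle; it only captures the core behaviour of dropping calls that arrive within timeout seconds of the last accepted one.

```julia
# Hypothetical simplified throttle; Flux.throttle itself may behave differently
# in details. Returns a wrapper that runs f at most once per timeout seconds.
function throttle_sketch(f, timeout)
    last_call = -Inf                  # time of the last accepted call
    function throttled(args...)
        now = time()
        if now - last_call >= timeout
            last_call = now
            return f(args...)
        end
        return nothing                # call dropped: too soon after the last one
    end
    return throttled
end

# Call the wrapped callback many times in quick succession.
calls = Ref(0)
cb = throttle_sketch(() -> (calls[] += 1), 60)
for _ in 1:100
    cb()
end
println("callback actually ran ", calls[], " time(s)")
```

Only the first of the hundred calls goes through, since the rest arrive well inside the 60-second window. This is why a throttled evaluation callback adds little overhead even when it is invoked for every batch.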