Training

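To actually train a model we need three things: a model loss function that evaluates how well the model is doing on some input data, a collection of data points to pass to that loss function, and an optimiser that updates the model parameters.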

With these we can call Flux.train!:

Flux.train!(modelLoss, data, opt)

There are plenty of examples in the model zoo.

Loss Functions

The loss function that we defined in the basics section is completely valid for training. We can also define a loss in terms of some model:

using Flux

# A small two-layer model: 784 inputs, 32 hidden units, 10 outputs
m = Chain(
  Dense(784, 32, σ),
  Dense(32, 10),
  softmax)

# Model loss function
loss(x, y) = Flux.mse(m(x), y)

# later
Flux.train!(loss, data, opt)

The loss will almost always be defined in terms of some cost function that measures the distance of the prediction m(x) from the target y. Flux has several of these built in, like mse for mean squared error or crossentropy for cross-entropy loss, but you can calculate the loss however you want.
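
To make "however you want" concrete: a cost function is just an ordinary Julia function of the prediction and the target. The sketch below writes mean squared error and cross entropy by hand in plain Julia. The names my_mse and my_crossentropy are hypothetical, and these are not Flux's own implementations, which may differ in details such as averaging or numerical safeguards.

```julia
# Hypothetical hand-written cost functions (not Flux's implementations).

# Mean squared error: average of the squared element-wise differences.
my_mse(ŷ, y) = sum((ŷ .- y) .^ 2) / length(y)

# Cross entropy between a predicted distribution ŷ and a target distribution y.
# Assumes every entry of ŷ is strictly positive (e.g. a softmax output).
my_crossentropy(ŷ, y) = -sum(y .* log.(ŷ))

ŷ = [0.7, 0.2, 0.1]   # example prediction
y = [1.0, 0.0, 0.0]   # one-hot target

my_mse(ŷ, y)            # (0.09 + 0.04 + 0.01) / 3
my_crossentropy(ŷ, y)   # -log(0.7)
```

A model loss could then be written in the same shape as before, for example loss(x, y) = my_mse(m(x), y).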

Datasets

The data argument provides a collection of data to train with (usually a set of inputs x and target outputs y). For example, here's a dummy data set with only one data point:

x = rand(784)
y = rand(10)
data = [(x, y)]

Flux.train! will call loss(x, y), calculate gradients, update the weights and then move on to the next data point if there is one. We can train the model on the same data three times:

data = [(x, y), (x, y), (x, y)]
# Or equivalently
data = Iterators.repeated((x, y), 3)
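
The loop described above can be sketched in plain Julia. This is only an illustration of the idea, not Flux's actual implementation: the one-weight model, toy_loss, toy_grad and sketch_train! are all hypothetical names, and the gradient is written out by hand where Flux would compute it automatically.

```julia
# Hypothetical sketch of a training loop; not Flux's actual implementation.
# The "model" is a single weight w, stored in a Ref so it can be updated in place.
w = Ref(0.0)
predict(x) = w[] * x
toy_loss(x, y) = (predict(x) - y)^2

# Analytic gradient of toy_loss with respect to w.
toy_grad(x, y) = 2 * (predict(x) - y) * x

function sketch_train!(loss, data; lr = 0.1, cb = () -> nothing)
    history = Float64[]
    for (x, y) in data
        push!(history, loss(x, y))   # evaluate the loss on this data point
        w[] -= lr * toy_grad(x, y)   # gradient-descent update of the weight
        cb()                         # run the callback once per data point
    end
    return history
end

data = Iterators.repeated((1.0, 2.0), 3)   # the same data point three times
history = sketch_train!(toy_loss, data)
```

Because the same point is seen three times, history holds three loss values, each smaller than the one before it.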

It's common to load the xs and ys separately. In this case you can use zip:

xs = [rand(784), rand(784), rand(784)]
ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)
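
As a quick check of what zip produces, the sketch below iterates over the zipped collection in plain Julia. It also shows one thing to watch for: zip stops at the shorter collection, so mismatched lengths drop data without an error. The truncated ys[1:2] is only there to illustrate that case.

```julia
xs = [rand(784), rand(784), rand(784)]
ys = [rand(10), rand(10), rand(10)]

# Each element produced by zip is an (x, y) tuple, just like a hand-written list of tuples.
for (x, y) in zip(xs, ys)
    println(size(x), " => ", size(y))
end

# zip stops at the shorter collection, so the third input is silently dropped here.
length(collect(zip(xs, ys[1:2])))   # 2
```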

Callbacks

train! takes an additional argument, cb, that's used for callbacks so that you can observe the training process. For example:

Flux.train!(loss, data, opt, cb = () -> println("training"))

Callbacks are called for every batch of training data. You can reduce how often a callback runs with Flux.throttle(f, timeout), which prevents f from being called more than once every timeout seconds.
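
To show the idea behind throttling, here is a minimal sketch in plain Julia. The name simple_throttle is hypothetical, and Flux.throttle's actual behaviour and options may differ. The sketch only assumes that the wrapped function should run at most once per timeout seconds and that calls made too soon are skipped.

```julia
# Hypothetical throttle sketch; Flux.throttle's actual behaviour may differ.
function simple_throttle(f, timeout)
    last_call = -Inf                  # time of the most recent call that ran f
    return function (args...)
        t = time()
        if t - last_call >= timeout
            last_call = t
            return f(args...)
        end
        return nothing                # call suppressed: too soon since the last run
    end
end

ncalls = Ref(0)
tick = simple_throttle(() -> (ncalls[] += 1), 5)

for _ in 1:1000
    tick()    # only the first call runs within the 5-second window
end
```

After the loop, ncalls[] is 1, because the remaining 999 calls all arrive well inside the 5-second window.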

A more typical callback might look like this:

# test_x, test_y = ... create a single batch of test data ...
evalcb() = @show(loss(test_x, test_y))

Flux.train!(loss, data, opt,
            cb = Flux.throttle(evalcb, 5))