Let's put FluxTraining.jl to work and train a model on the MNIST dataset.
MNIST is simple enough that we can focus on the part where FluxTraining.jl comes in: the training. If you want to see examples of using FluxTraining.jl on larger datasets, see the documentation of FastAI.jl.
If you want to run this tutorial yourself, you can find the notebook file here.
To make data loading and batching a bit easier, we'll install an additional dependency:
using Pkg; Pkg.add(["MLUtils"])
Now we can import everything we'll need.
using MLUtils: splitobs, unsqueeze
using MLDatasets: MNIST
using Flux
using Flux: onehotbatch
using Flux.Data: DataLoader
using FluxTraining
There are 4 pieces that you always need to construct and train a Learner:

- a model
- data
- an optimizer
- a loss function
Let's look at the data first.
FluxTraining.jl is agnostic of the data source. The only requirements, illustrated by the sketch below, are:

- it is iterable and each iteration returns a tuple (xs, ys);
- the model can take in xs, i.e. model(xs) works; and
- the loss function can take model outputs and ys, i.e. lossfn(model(xs), ys) returns a scalar.
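For illustration, here is a minimal hand-rolled data source that satisfies all three requirements. It is hypothetical (the RandomBatches name and the batch size of 8 are ours, not part of the tutorial); below we use a DataLoader instead.

# Hypothetical example: any iterable of (xs, ys) tuples is a valid data source.
struct RandomBatches
    nbatches::Int
end

function Base.iterate(it::RandomBatches, i = 1)
    i > it.nbatches && return nothing
    xs = rand(Float32, 28, 28, 1, 8)      # a batch of 8 fake grayscale images
    ys = onehotbatch(rand(0:9, 8), 0:9)   # matching one-hot targets
    return (xs, ys), i + 1
end

Base.length(it::RandomBatches) = it.nbatches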
Glossing over the details, as they're not the focus of this tutorial, here's the code for getting a data iterator of the MNIST dataset. We use DataLoader to create an iterator of batches from our dataset.
data = MNIST(:train)[:]

const LABELS = 0:9

# unsqueeze to reshape from (28, 28, numobs) to (28, 28, 1, numobs)
function preprocess((data, targets))
    return unsqueeze(data, 3), onehotbatch(targets, LABELS)
end

# traindata and testdata contain both inputs (pixel values) and targets (correct labels)
traindata = MNIST(Float32, :train)[:] |> preprocess
testdata = MNIST(Float32, :test)[:] |> preprocess

# create iterators
trainiter, testiter = DataLoader(traindata, batchsize = 128), DataLoader(testdata, batchsize = 256);
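To verify the pipeline produces what the model will expect, you can peek at a single batch. This check is our addition, not part of the original tutorial:

# Inspect one training batch; the shapes follow from the batch size of 128 chosen above.
xs, ys = first(trainiter)
size(xs)  # (28, 28, 1, 128): width, height, channels, batch
size(ys)  # (10, 128): one-hot targets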
Next, let's create a simple Flux.jl model that we'll train to classify the MNIST digits.
model = Chain(
    Conv((3, 3), 1 => 16, relu, pad = 1, stride = 2),
    Conv((3, 3), 16 => 32, relu, pad = 1),
    GlobalMeanPool(),
    Flux.flatten,
    Dense(32, 10),
)
Chain(
Conv((3, 3), 1 => 16, relu, pad=1, stride=2), # 160 parameters
Conv((3, 3), 16 => 32, relu, pad=1), # 4_640 parameters
GlobalMeanPool(),
Flux.flatten,
Dense(32 => 10), # 330 parameters
) # Total: 6 arrays, 5_130 parameters, 20.867 KiB.
We'll use categorical cross entropy as a loss function and ADAM as an optimizer.
lossfn = Flux.Losses.logitcrossentropy
optimizer = Flux.ADAM();
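Before wiring everything together, it's worth confirming the contract from above, namely that lossfn(model(xs), ys) returns a scalar. This sanity check is our addition, not part of the original tutorial:

# The loss of one batch should be a single number
# (around log(10) ≈ 2.3 for an untrained 10-class classifier).
xs, ys = first(trainiter)
lossfn(model(xs), ys)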
Now we're ready to create a Learner. At this point you can also add any callbacks, like ToGPU to run the training on your GPU if you have one available. Some callbacks are also included by default.

Since we're classifying digits, we also use the Metrics callback to track the accuracy of the model's predictions:
learner = Learner(model, lossfn; callbacks = [Metrics(accuracy)], optimizer)

Learner()
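If you do have a CUDA-capable GPU available, a variant with the ToGPU callback mentioned above might look like this (an assumption on our part; it is not run in this tutorial):

learner = Learner(model, lossfn; callbacks = [ToGPU(), Metrics(accuracy)], optimizer)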
With a Learner in place, training is as simple as calling fit!(learner, nepochs, dataiters).
FluxTraining.fit!(learner, 10, (trainiter, testiter))
Epoch 1 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 1.0 │ 1.92511 │ 0.30643 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 1 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 1.0 │ 1.56384 │ 0.429 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 2 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 2.0 │ 1.40887 │ 0.53872 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 2 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 2.0 │ 1.25592 │ 0.57578 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 3 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 3.0 │ 1.17799 │ 0.64023 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 3 ValidationPhase() ...
┌─────────────────┬───────┬────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼────────┼──────────┤
│ ValidationPhase │ 3.0 │ 1.0728 │ 0.63809 │
└─────────────────┴───────┴────────┴──────────┘
Epoch 4 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 4.0 │ 1.02755 │ 0.69311 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 4 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 4.0 │ 0.95165 │ 0.67871 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 5 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 5.0 │ 0.91992 │ 0.72615 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 5 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 5.0 │ 0.86318 │ 0.70791 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 6 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 6.0 │ 0.83481 │ 0.75419 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 6 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 6.0 │ 0.79026 │ 0.73604 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 7 TrainingPhase() ...
┌───────────────┬───────┬────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼────────┼──────────┤
│ TrainingPhase │ 7.0 │ 0.7636 │ 0.77762 │
└───────────────┴───────┴────────┴──────────┘
Epoch 7 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 7.0 │ 0.72974 │ 0.75869 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 8 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 8.0 │ 0.70309 │ 0.79626 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 8 ValidationPhase() ...
┌─────────────────┬───────┬────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼────────┼──────────┤
│ ValidationPhase │ 8.0 │ 0.6784 │ 0.77363 │
└─────────────────┴───────┴────────┴──────────┘
Epoch 9 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 9.0 │ 0.65131 │ 0.81166 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 9 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 9.0 │ 0.63317 │ 0.79014 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 10 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 10.0 │ 0.60669 │ 0.82449 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 10 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 10.0 │ 0.59359 │ 0.80527 │
└─────────────────┴───────┴─────────┴──────────┘
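After training, the updated model is available as learner.model. As a final, hypothetical check (again our addition, not part of the tutorial), we can decode its predictions on a test batch with Flux.onecold:

# onecold is the inverse of onehotbatch: it maps each column of logits
# back to the label with the highest score.
xs, ys = first(testiter)
preds = Flux.onecold(learner.model(xs), LABELS)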