Let’s use FluxTraining.jl to train a model on the MNIST dataset.

MNIST is simple enough that we can focus on the part where FluxTraining.jl comes in: the training. If you want to see examples of using FluxTraining.jl on larger datasets, see the documentation of FastAI.jl.

If you want to run this tutorial yourself, you can find the notebook file here.
To make data loading and batching a bit easier, we’ll install some additional dependencies:
```julia
using Pkg; Pkg.add(["MLDataPattern", "DataLoaders"])
```
Now we can import everything we’ll need.
```julia
using DataLoaders: DataLoader
using MLDataPattern: splitobs
using Flux
using FluxTraining
```
There are four pieces that you always need to construct and train a `Learner`:

- a model
- data
- an optimizer; and
- a loss function
Let’s look at the data first.
FluxTraining.jl is agnostic of the data source. The only requirements are:

- it is iterable and each iteration returns a tuple `(xs, ys)`;
- the model can take in `xs`, i.e. `model(xs)` works; and
- the loss function can take model outputs and `ys`, i.e. `lossfn(model(xs), ys)` returns a scalar.
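Concretely, any iterable of `(xs, ys)` tuples satisfies this contract — even a plain vector. A minimal sketch with fake data (the array shapes here are illustrative assumptions, not requirements):

```julia
# A toy "dataset": a vector of (xs, ys) batches built from random numbers.
xs_batch = rand(Float32, 28, 28, 1, 8)   # 8 fake 28x28 grayscale images
ys_batch = rand(Float32, 10, 8)          # 8 fake 10-class targets
data = [(xs_batch, ys_batch), (xs_batch, ys_batch)]

# The contract: iterating yields (xs, ys) tuples whose batch sizes agree.
for (xs, ys) in data
    @assert size(xs, 4) == size(ys, 2)
end
```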
Glossing over the details, as it’s not the focus of this tutorial, here’s the code for getting a data iterator of the MNIST dataset. We use `DataLoaders.DataLoader` to create an iterator of batches from our dataset.
```julia
xs, ys = (
    # convert each image into h*w*1 array of floats
    [Float32.(reshape(img, 28, 28, 1)) for img in Flux.Data.MNIST.images()],
    # one-hot encode the labels
    [Float32.(Flux.onehot(y, 0:9)) for y in Flux.Data.MNIST.labels()],
)
```
```julia
# split into training and validation sets
traindata, valdata = splitobs((xs, ys))

# create iterators
trainiter, valiter = DataLoader(traindata, 128, buffered=false),
                     DataLoader(valdata, 256, buffered=false);
```
Next, let’s create a simple Flux.jl model that we’ll train to classify the MNIST digits.
```julia
model = Chain(
    Conv((3, 3), 1 => 16, relu, pad=1, stride=2),
    Conv((3, 3), 16 => 32, relu, pad=1),
    GlobalMeanPool(),
    Flux.flatten,
    Dense(32, 10),
)
```

```
Chain(Conv((3, 3), 1=>16, relu), Conv((3, 3), 16=>32, relu), GlobalMeanPool(), flatten, Dense(32, 10))
```
We’ll use categorical cross-entropy as a loss function and ADAM as an optimizer.
```julia
lossfn = Flux.Losses.logitcrossentropy
optimizer = Flux.ADAM();
```
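Note that the model above ends in a plain `Dense(32, 10)` with no softmax: `logitcrossentropy` expects raw scores (logits) and applies the softmax internally in a numerically stable way. A quick pure-Julia sketch of what it computes — the `softmax` helper here is a local stand-in for illustration, not Flux's:

```julia
logits = [2.0, 0.5, -1.0]   # raw model outputs for 3 classes
y      = [1.0, 0.0, 0.0]    # one-hot target

# numerically stable softmax (shift by the maximum before exponentiating)
softmax(x) = exp.(x .- maximum(x)) ./ sum(exp.(x .- maximum(x)))

# cross-entropy on softmaxed logits; Flux's logitcrossentropy fuses
# these two steps into one numerically stable operation
loss = -sum(y .* log.(softmax(logits)))
```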
Now we’re ready to create a `Learner`. At this point you can also add any callbacks, like `ToGPU` to run the training on your GPU if you have one available. Some callbacks are also included by default.

Since we’re classifying digits, we also use the `Metrics` callback to track the accuracy of the model’s predictions:
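A minimal sketch of the `Learner` construction, assuming FluxTraining.jl’s keyword-based constructor — the exact signature may differ between versions, so check the `Learner` docstring for your installed release:

```julia
# Hypothetical Learner construction; the callback list and keyword names
# are assumptions based on FluxTraining.jl's documented API.
learner = Learner(
    model, lossfn;
    optimizer = optimizer,
    callbacks = [ToGPU(), Metrics(accuracy)],
)
```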
With a `Learner` in place, training is as simple as calling `fit!(learner, nepochs, dataiters)`.
```julia
FluxTraining.fit!(learner, 10, (trainiter, valiter))
```
Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:46
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 1.0 │ 2.04939 │ 0.25204 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 1.0 │ 1.70353 │ 0.3821 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:19
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 2.0 │ 1.58615 │ 0.44849 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 2.0 │ 1.44792 │ 0.50544 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 3.0 │ 1.36495 │ 0.57273 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 3.0 │ 1.25941 │ 0.59525 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:20
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 4.0 │ 1.18935 │ 0.64891 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼────────┼──────────┤
│ ValidationPhase │ 4.0 │ 1.1076 │ 0.66347 │
└─────────────────┴───────┴────────┴──────────┘
Epoch 5 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:19
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 5.0 │ 1.05506 │ 0.69386 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 5 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 5.0 │ 0.99203 │ 0.70275 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 6 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 6.0 │ 0.95282 │ 0.72533 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 6 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 6.0 │ 0.90209 │ 0.73058 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 7 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:19
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 7.0 │ 0.87621 │ 0.74563 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 7 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 7.0 │ 0.83402 │ 0.74781 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 8 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 8.0 │ 0.81399 │ 0.76282 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 8 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 8.0 │ 0.77623 │ 0.76568 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 9 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 9.0 │ 0.76236 │ 0.77835 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 9 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 9.0 │ 0.72606 │ 0.78079 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 10 TrainingPhase(): 100%|█████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 10.0 │ 0.71684 │ 0.79175 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 10 ValidationPhase(): 100%|███████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 10.0 │ 0.68353 │ 0.79449 │
└─────────────────┴───────┴─────────┴──────────┘
Learner()
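After training, the fitted model can be pulled out of the learner for inference. A hedged usage sketch, assuming the `learner.model` field and the iterators defined above (moving the model back to the CPU in case `ToGPU` was used):

```julia
# Run the trained model on one validation batch and decode predictions.
trainedmodel = Flux.cpu(learner.model)
xs, ys = first(valiter)
preds = Flux.onecold(trainedmodel(xs), 0:9)   # predicted digit per sample
```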