Training an image classifier

Let's use FluxTraining.jl to train a model on the MNIST dataset.

MNIST is simple enough that we can focus on the part where FluxTraining.jl comes in: the training. If you want to see examples of using FluxTraining.jl on larger datasets, see the documentation of FastAI.jl.

Setup

If you want to run this tutorial yourself, you can find the notebook file here.

To make data loading and batching a bit easier, we'll install a couple of additional dependencies:


using Pkg; Pkg.add(["MLUtils", "MLDatasets"])

Now we can import everything we'll need.


using MLUtils: splitobs, unsqueeze
using MLDatasets: MNIST
using Flux
using Flux: onehotbatch
using Flux.Data: DataLoader
using FluxTraining

Overview

There are four pieces that you always need to construct and train a Learner:

  • a model

  • data

  • an optimizer; and

  • a loss function

Building a Learner

Let's look at the data first.

FluxTraining.jl is agnostic of the data source. The only requirements are:

  • it is iterable and each iteration returns a tuple (xs, ys);

  • the model can take in xs, i.e. model(xs) works; and

  • the loss function can take model outputs and ys, i.e. lossfn(model(xs), ys) returns a scalar (a toy sketch follows this list).
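
For instance, a plain vector of batches already satisfies all three requirements. Here's a toy sketch (the arrays are made up for illustration and unrelated to the MNIST pipeline below):

xs = rand(Float32, 28, 28, 1, 8)          # a fake batch of 8 grayscale images
ys = Flux.onehotbatch(rand(0:9, 8), 0:9)  # matching one-hot targets
toydata = [(xs, ys)]                      # an iterable of (xs, ys) tuples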

Glossing over the details, since data loading is not the focus of this tutorial, here's the code for getting a data iterator over the MNIST dataset. We use Flux.Data.DataLoader to create an iterator of batches from our dataset.


const LABELS = 0:9

# unsqueeze to reshape from (28, 28, numobs) to (28, 28, 1, numobs)
function preprocess((data, targets))
    return unsqueeze(data, 3), onehotbatch(targets, LABELS)
end

# traindata and testdata contain both inputs (pixel values) and targets (correct labels)
traindata = MNIST(Float32, :train)[:] |> preprocess
testdata = MNIST(Float32, :test)[:] |> preprocess

# create iterators
trainiter, testiter = DataLoader(traindata, batchsize=128), DataLoader(testdata, batchsize=256);
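
As a quick sanity check (a small sketch using the iterators we just built), we can peek at the first batch and confirm the shapes the model will see:

xs, ys = first(trainiter)
size(xs)  # (28, 28, 1, 128)
size(ys)  # (10, 128)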

Next, let's create a simple Flux.jl model that we'll train to classify the MNIST digits.


model = Chain(
    Conv((3, 3), 1 => 16, relu, pad = 1, stride = 2),
    Conv((3, 3), 16 => 32, relu, pad = 1),
    GlobalMeanPool(),
    Flux.flatten,
    Dense(32, 10),
)

Chain(
  Conv((3, 3), 1 => 16, relu, pad=1, stride=2),  # 160 parameters
  Conv((3, 3), 16 => 32, relu, pad=1),           # 4_640 parameters
  GlobalMeanPool(),
  Flux.flatten,
  Dense(32 => 10),                               # 330 parameters
)                   # Total: 6 arrays, 5_130 parameters, 20.867 KiB.
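
Before wiring this into a Learner, it's worth verifying that the model maps a batch to one logit per class (a quick sketch reusing the training iterator from above):

xs, _ = first(trainiter)
size(model(xs))  # (10, 128): 10 class logits for each of the 128 images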

We'll use categorical cross-entropy as the loss function and ADAM as the optimizer.


lossfn = Flux.Losses.logitcrossentropy
optimizer = Flux.ADAM();
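
Note that the model outputs raw logits and logitcrossentropy applies the softmax internally, which is numerically more stable than calling crossentropy on softmax outputs. For an untrained 10-class model the loss should sit near log(10) ≈ 2.3; a rough sketch:

xs, ys = first(trainiter)
lossfn(model(xs), ys)  # ≈ 2.3 before any training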

Now we're ready to create a Learner. At this point you can also add any callbacks, like ToGPU to run the training on your GPU if you have one available. Some callbacks are also included by default.

Since we're classifying digits, we also use the Metrics callback to track the accuracy of the model's predictions:


learner = Learner(model, lossfn; callbacks=[ToGPU(), Metrics(accuracy)], optimizer)

Learner()
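
If you don't have a GPU available, simply leave out the ToGPU() callback; a minimal variant of the same call would be:

learner = Learner(model, lossfn; callbacks=[Metrics(accuracy)], optimizer)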

Training

With a Learner in place, training is as simple as calling fit!(learner, nepochs, dataiters).


FluxTraining.fit!(learner, 10, (trainiter, testiter))

Epoch 1 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   1.0 │ 1.92511 │  0.30643 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 1 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   1.0 │ 1.56384 │    0.429 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 2 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   2.0 │ 1.40887 │  0.53872 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 2 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   2.0 │ 1.25592 │  0.57578 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 3 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   3.0 │ 1.17799 │  0.64023 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 3 ValidationPhase() ...
┌─────────────────┬───────┬────────┬──────────┐
│           Phase │ Epoch │   Loss │ Accuracy │
├─────────────────┼───────┼────────┼──────────┤
│ ValidationPhase │   3.0 │ 1.0728 │  0.63809 │
└─────────────────┴───────┴────────┴──────────┘
Epoch 4 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   4.0 │ 1.02755 │  0.69311 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 4 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   4.0 │ 0.95165 │  0.67871 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 5 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   5.0 │ 0.91992 │  0.72615 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 5 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   5.0 │ 0.86318 │  0.70791 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 6 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   6.0 │ 0.83481 │  0.75419 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 6 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   6.0 │ 0.79026 │  0.73604 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 7 TrainingPhase() ...
┌───────────────┬───────┬────────┬──────────┐
│         Phase │ Epoch │   Loss │ Accuracy │
├───────────────┼───────┼────────┼──────────┤
│ TrainingPhase │   7.0 │ 0.7636 │  0.77762 │
└───────────────┴───────┴────────┴──────────┘
Epoch 7 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   7.0 │ 0.72974 │  0.75869 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 8 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   8.0 │ 0.70309 │  0.79626 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 8 ValidationPhase() ...
┌─────────────────┬───────┬────────┬──────────┐
│           Phase │ Epoch │   Loss │ Accuracy │
├─────────────────┼───────┼────────┼──────────┤
│ ValidationPhase │   8.0 │ 0.6784 │  0.77363 │
└─────────────────┴───────┴────────┴──────────┘
Epoch 9 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   9.0 │ 0.65131 │  0.81166 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 9 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   9.0 │ 0.63317 │  0.79014 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 10 TrainingPhase() ...
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │  10.0 │ 0.60669 │  0.82449 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 10 ValidationPhase() ...
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │  10.0 │ 0.59359 │  0.80527 │
└─────────────────┴───────┴─────────┴──────────┘
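
After training, the (possibly GPU-resident) model is available as learner.model, so we can sketch a quick round of inference on a test batch (cpu is a no-op if you trained without a GPU):

xs, ys = first(testiter)
logits = cpu(learner.model)(xs)
Flux.onecold(logits, LABELS)  # predicted digits for the batch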