Training an image classifier

Let’s use FluxTraining.jl to train a model on the MNIST dataset.

MNIST is simple enough that we can focus on the part where FluxTraining.jl comes in: the training. If you want to see examples of using FluxTraining.jl on larger datasets, see the documentation of FastAI.jl.

Setup

If you want to run this tutorial yourself, you can find the notebook file here.

To make data loading and batching a bit easier, we’ll install some additional dependencies:

    using Pkg; Pkg.add(["MLDataPattern", "DataLoaders"])

Now we can import everything we’ll need.

    using DataLoaders: DataLoader
    using MLDataPattern: splitobs
    using Flux
    using FluxTraining

Overview

There are four pieces you always need to construct and train a Learner:

  • a model

  • data

  • an optimizer; and

  • a loss function

Building a Learner

Let’s look at the data first.

FluxTraining.jl is agnostic of the data source. The only requirements, illustrated by the sketch after this list, are:

  • it is iterable and each iteration returns a tuple (xs, ys);

  • the model can take in xs, i.e. model(xs) works; and

  • the loss function can take model outputs and ys, i.e. lossfn(model(xs), ys) returns a scalar
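
To make this contract concrete, here’s a minimal sketch of a hand-rolled data source. A plain vector of batch tuples already satisfies all three requirements; the random arrays below are a hypothetical stand-in, not the tutorial’s MNIST data.

    # A Vector of (xs, ys) tuples is iterable, so it is a valid data source.
    # Ten toy batches of 8 random "images" with one-hot encoded labels:
    toydata = [(rand(Float32, 28, 28, 1, 8), Float32.(Flux.onehotbatch(rand(0:9, 8), 0:9)))
               for _ in 1:10]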

Glossing over the details, as they’re not the focus of this tutorial, here’s the code for getting a data iterator over the MNIST dataset. We use DataLoaders.DataLoader to create an iterator of batches from our dataset.

    xs, ys = (
        # convert each image into an h*w*1 array of floats
        [Float32.(reshape(img, 28, 28, 1)) for img in Flux.Data.MNIST.images()],
        # one-hot encode the labels
        [Float32.(Flux.onehot(y, 0:9)) for y in Flux.Data.MNIST.labels()],
    )

    # split into training and validation sets
    traindata, valdata = splitobs((xs, ys))

    # create iterators
    trainiter, valiter = DataLoader(traindata, 128, buffered=false), DataLoader(valdata, 256, buffered=false);
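
As a quick sanity check (an addition to the tutorial, assuming the iterators above), you can pull the first batch and confirm its shape:

    # One training batch: 28×28×1 images stacked into a 4-D array of 128 samples.
    batchxs, batchys = first(trainiter)
    size(batchxs)  # (28, 28, 1, 128)
    size(batchys)  # (10, 128)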

Next, let’s create a simple Flux.jl model that we’ll train to classify the MNIST digits.

    model = Chain(
        Conv((3, 3), 1 => 16, relu, pad = 1, stride = 2),
        Conv((3, 3), 16 => 32, relu, pad = 1),
        GlobalMeanPool(),
        Flux.flatten,
        Dense(32, 10),
    )

    Chain(Conv((3, 3), 1=>16, relu), Conv((3, 3), 16=>32, relu), GlobalMeanPool(), flatten, Dense(32, 10))
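
To double-check the architecture (a quick sketch, not part of the original tutorial), run a dummy batch through the model; the output should have one row per digit class:

    # A single random 28×28 grayscale image as a batch of size 1.
    dummy = rand(Float32, 28, 28, 1, 1)
    size(model(dummy))  # (10, 1)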

We’ll use categorical cross-entropy as a loss function and ADAM as an optimizer.

    lossfn = Flux.Losses.logitcrossentropy
    optimizer = Flux.ADAM();
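
Note that logitcrossentropy expects raw logits, which is why the model above has no final softmax layer. Reusing the batch pulled earlier (a sketch), we can verify the third data requirement, that the loss evaluates to a scalar:

    # lossfn(model(xs), ys) must return a scalar; this yields a Float32.
    lossfn(model(batchxs), batchys)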

Now we’re ready to create a Learner. At this point you can also add any callbacks, like ToGPU to run the training on your GPU if you have one available. Some callbacks are also included by default.

Since we’re classifying digits, we also use the Metrics callback to track the accuracy of the model’s predictions:

    learner = Learner(model, lossfn; callbacks=[Metrics(accuracy)], optimizer)

    Learner()
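
If you do have a CUDA-capable GPU, a variant with the ToGPU callback mentioned above might look like this (a sketch; it assumes the same model, loss function, and optimizer):

    # ToGPU() moves the model and each batch to the GPU during training.
    learner = Learner(model, lossfn;
                      callbacks=[ToGPU(), Metrics(accuracy)], optimizer)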

Training

With a Learner in place, training is as simple as calling fit!(learner, nepochs, dataiters).

    FluxTraining.fit!(learner, 10, (trainiter, valiter))

			Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:46
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   1.0 │ 2.04939 │  0.25204 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   1.0 │ 1.70353 │   0.3821 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:19
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   2.0 │ 1.58615 │  0.44849 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   2.0 │ 1.44792 │  0.50544 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   3.0 │ 1.36495 │  0.57273 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   3.0 │ 1.25941 │  0.59525 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:20
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   4.0 │ 1.18935 │  0.64891 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬────────┬──────────┐
│           Phase │ Epoch │   Loss │ Accuracy │
├─────────────────┼───────┼────────┼──────────┤
│ ValidationPhase │   4.0 │ 1.1076 │  0.66347 │
└─────────────────┴───────┴────────┴──────────┘
Epoch 5 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:19
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   5.0 │ 1.05506 │  0.69386 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 5 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   5.0 │ 0.99203 │  0.70275 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 6 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   6.0 │ 0.95282 │  0.72533 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 6 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   6.0 │ 0.90209 │  0.73058 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 7 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:19
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   7.0 │ 0.87621 │  0.74563 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 7 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   7.0 │ 0.83402 │  0.74781 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 8 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   8.0 │ 0.81399 │  0.76282 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 8 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   8.0 │ 0.77623 │  0.76568 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 9 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   9.0 │ 0.76236 │  0.77835 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 9 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   9.0 │ 0.72606 │  0.78079 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 10 TrainingPhase(): 100%|█████████████████████████| Time: 0:00:18
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │  10.0 │ 0.71684 │  0.79175 │
└───────────────┴───────┴─────────┴──────────┘
Epoch 10 ValidationPhase(): 100%|███████████████████████| Time: 0:00:02
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │  10.0 │ 0.68353 │  0.79449 │
└─────────────────┴───────┴─────────┴──────────┘

			Learner()
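
fit! runs one training and one validation phase per epoch. If you need finer-grained control, FluxTraining.jl also lets you run phases one epoch at a time; here is a sketch using epoch! with the same iterators (the phase names are those visible in the log above):

    # Run a single training epoch followed by a single validation epoch.
    FluxTraining.epoch!(learner, TrainingPhase(), trainiter)
    FluxTraining.epoch!(learner, ValidationPhase(), valiter)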