This tutorial is adapted from the corresponding tutorial in fast.ai's documentation. It tries to stay as close to the original as possible, but diverges where the APIs differ.
In this tutorial, we will see how to deal with a new type of task using the middle layer of the fastai library. The example we will use is a Siamese network that takes two images and determines whether they are of the same class or not. In particular, we will see:
- how to quickly get DataLoaders from standard PyTorch Datasets
- how to adapt this in a Transform to get some of the show features of fastai
- how to add some new behavior to show_batch/show_results for a custom task
- how to write a custom DataBlock
- how to create your own model from a pretrained model
- how to pass along a custom splitter to Learner to take advantage of transfer learning
Update: Since we'll be implementing some image operations manually, we'll add the `Images` package:

```julia
using Pkg; Pkg.add("Images")
```
To make our data ready for training a model, we need to create data iterators for training and validation, for example using `DataLoader`. Usually, the first step is to create a data container that is then wrapped inside a `DataLoader`. Unlike in fast.ai, FastAI.jl separates the loading part from the encoding part. In this case, loading means getting pairs of images, and encoding includes preprocessing and augmenting them. We'll first create a data container that just loads pairs of images and later show how to apply transforms on top of that.
```julia
using FastAI
import CairoMakie; CairoMakie.activate!(type="png")
```
First we'll use `datasets` to download and untar the dataset and then find all image files.
```julia
path = joinpath(load(datasets()["oxford-iiit-pet"]), "images")
files = loadfolderdata(path; filterfn = isimagefile)
files[1]
```
"/home/lorenz/.julia/datadeps/fastai-oxford-iiit-pet/oxford-iiit-pet/images/Abyssinian_1.jpg"
We can open the first image and have a look at it. Note that array indices start at 1 in Julia.
```julia
image = loadfile(files[1])
```
Let's wrap all the standard preprocessing (resizing, conversion to a tensor, and reordering of the channels) in one helper function. Note some differences to Python here:

- images by default are 2D arrays of pixels, so we need to use `channelview` to get a 3D array with the color dimension expanded
- pixel values are assumed to be between 0 and 1, so we do not need to divide them by 255

For flexibility, we also separate the loading (`loadfile`) from the transformations applied to the image.
```julia
using Images

function transform_image(image, sz = 224)
    image_resized = imresize(convert.(RGB{N0f8}, image), (sz, sz))
    a = permuteddimsview(channelview(image_resized), (2, 3, 1))
end
```
transform_image (generic function with 2 methods)
```julia
transform_image(loadfile(files[1])) |> summary
```
"224×224×3 PermutedDimsArray(reinterpret(reshape, N0f8, ::Array{RGB{N0f8},2}), (2, 3, 1)) with eltype N0f8"
We can see that the label of our image is in the filename, before the last `_` and some number. We can then use a regular expression to create a label function:
```julia
label_func(path) = match(r"^(.*)_\d+\.jpg$", pathname(path))[1]
label_func(files[1])
```
"Abyssinian"
Now let's gather all unique labels:
```julia
labels = map(label_func, files)
length(unique(labels))
```
37
We could now use `mapobs` to create a data container from our list of files. It lazily applies a function (like loading an image) to each item, and we can get single observations using `getobs` and the number of observations with `numobs`. This is analogous to a `torch.utils.data.Dataset`. For example, the following creates a data container with tuples of an image and a category that could be used for image classification:
```julia
data = mapobs(files) do file
    return (loadfile(file), label_func(file))
end

@show numobs(data)
image, label = getobs(data, 1)
```
numobs(data) = 7390
(RGB{N0f8}[RGB{N0f8}(0.11,0.149,0.106) RGB{N0f8}(0.11,0.149,0.106) … RGB{N0f8}(0.153,0.196,0.141) RGB{N0f8}(0.149,0.192,0.137); RGB{N0f8}(0.106,0.145,0.102) RGB{N0f8}(0.106,0.145,0.102) … RGB{N0f8}(0.157,0.2,0.145) RGB{N0f8}(0.149,0.192,0.137); … ; RGB{N0f8}(0.047,0.075,0.043) RGB{N0f8}(0.043,0.071,0.039) … RGB{N0f8}(0.059,0.09,0.047) RGB{N0f8}(0.059,0.09,0.047); RGB{N0f8}(0.047,0.075,0.043) RGB{N0f8}(0.043,0.071,0.039) … RGB{N0f8}(0.059,0.09,0.047) RGB{N0f8}(0.059,0.09,0.047)], "Abyssinian")
To create our Siamese datasets, however, we will need to create tuples of images for inputs, and the target will be `true` if the images are of the same class, `false` otherwise.

First we'll shuffle the files and split them into a training and a validation set.
```julia
using Random

idxs = shuffle(1:length(files))
cut = round(Int, 0.8 * length(idxs))
trainidxs, valididxs = idxs[1:cut], idxs[cut+1:end]
trainfiles, validfiles = files[trainidxs], files[valididxs]
summary.((trainfiles, validfiles))
```
("5912-element Vector{String}", "1478-element Vector{String}")
Let's create a custom data container that returns pairs of indices and a Boolean indicating whether the labels are the same. Half of the pairs will have the same label, and half will not. Additionally, during training the other image will be chosen randomly, while for validation the pairs will always be the same.
While you can get far with basic data containers like `loadfolderdata` and transformations like `mapobs` and `filterobs`, sometimes it's simplest to create a custom data container. You just need to implement `getobs` and `numobs` for your type, similar to how you would implement `__getitem__` and `__len__` for a PyTorch `Dataset`.
```julia
import FastAI.MLUtils

struct SiamesePairs
    labels
    same
    other
    valid
end

function SiamesePairs(labels; valid = false)
    ulabels = unique(labels)
    same = Dict(
        label => [i for (i, l) in enumerate(labels) if l == label]
        for label in ulabels)
    other = Dict(
        label => [i for (i, l) in enumerate(labels) if l != label]
        for label in ulabels)
    return SiamesePairs(labels, same, other, valid)
end

function MLUtils.getobs(si::SiamesePairs, idx::Int)
    rng = si.valid ? MersenneTwister(idx) : Random.GLOBAL_RNG
    if rand(rng) > 0.5
        return ((idx, rand(rng, si.same[si.labels[idx]])), true)
    else
        return ((idx, rand(rng, si.other[si.labels[idx]])), false)
    end
end

MLUtils.numobs(si::SiamesePairs) = length(si.labels)
```
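To sanity-check the container, we can construct it from the training labels and fetch a single observation. This is a small usage sketch, not part of the original code; since the training container uses the global RNG, the exact indices drawn will vary.

```julia
si = SiamesePairs(map(label_func, trainfiles))
numobs(si)       # 5912, one observation per training file
getobs(si, 1)    # something like ((1, 2831), true): a pair of indices and a same-class flag
```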
We can combine this data container, which gives us pairs of indices, with the files and map the loading and preprocessing functions over it to get a data container that is ready to be passed to a `DataLoader`.
```julia
function siamesedata(files; valid = false, transformfn = identity)
    labels = map(label_func, files)
    si = SiamesePairs(labels; valid = valid)
    return mapobs(si) do obs
        (i, j), same = obs
        image1 = transformfn(loadfile(files[i]))
        image2 = transformfn(loadfile(files[j]))
        return ((image1, image2), same)
    end
end
```
siamesedata (generic function with 1 method)
```julia
traindata = siamesedata(trainfiles; transformfn = transform_image)
validdata = siamesedata(validfiles; transformfn = transform_image, valid = true);
```
We can see that each observation consists of two images and a Boolean:
```julia
summary.(getobs(traindata, 1))
```
("Tuple{PermutedDimsArray{N0f8, 3, (2, 3, 1), (3, 1, 2), Base.ReinterpretArray{N0f8, 3, RGB{N0f8}, Matrix{RGB{N0f8}}, true}}, PermutedDimsArray{N0f8, 3, (2, 3, 1), (3, 1, 2), Base.ReinterpretArray{N0f8, 3, RGB{N0f8}, Matrix{RGB{N0f8}}, true}}}", "Bool")
To use the above data containers in training, we can simply pass them to a `DataLoader`.
```julia
traindl = DataLoader(shuffleobs(traindata), 16)
validdl = DataLoader(validdata, 32)
```
Next, let's look at how we can extend this example to make use of FastAI.jl's data augmentation, visualize our data, and more.
This is where FastAI.jl's API diverges a bit from fast.ai's. Note how above, we made sure to separate the data container creation and loading from disk from the preprocessing that is applied to every observation. In FastAI.jl, the preprocessing or "encoding" is implemented through a learning task. Learning tasks contain any configuration and, besides data processing, have extensible functions for visualization and model building. One advantage of this separation between loading and encoding is that the data container can easily be swapped out as long as it has observations suitable for the learning task (in this case a tuple of two images and a Boolean). It also makes it easy to export models together with all the necessary configuration.
The easiest way to create learning tasks is the data block API, which should suit the vast majority of use cases. It is also possible to directly implement the lower-level `LearningTask` interface.
The best way to understand it is to use it, so let's build a learning task for Siamese image similarity. We specify the kinds of input and target data as blocks. Here, we have two 2D images as input (`(Image{2}(), Image{2}())`) and a binary label (`Label([true, false])`) as output. We also pass in a tuple of encodings that describe how the data is transformed before being fed to a model. Here `ProjectiveTransforms` resizes the images to the same size, `ImagePreprocessing` converts the images to the right format, and `OneHot` one-hot encodes the labels.
```julia
task = BlockTask(
    (
        (Image{2}(), Image{2}()),
        Label([true, false])
    ),
    (
        ProjectiveTransforms((128, 128), buffered = false, sharestate = false),
        ImagePreprocessing(buffered = false),
        OneHot(),
    )
)
```
BlockTask(Tuple{Image{2}, Image{2}} -> Label{Bool})
We can get a better understanding of the representations the data goes through using `describetask`:
```julia
describetask(task)
```
We can reuse all the code above for creating the data container; we just omit the preprocessing function. `taskdataloaders` constructs training and validation data loaders from data containers by mapping the task's encodings over them:
```julia
traindata = siamesedata(trainfiles; valid = true)
validdata = siamesedata(validfiles; valid = true);

traindl, valdl = taskdataloaders(traindata, validdata, task, 32)
```
(eachobsparallel(batchviewcollated() with 185 batches of size 32), eachobsparallel(batchviewcollated() with 24 batches of size 64))
The above is equivalent to lazily mapping the encoding over the data containers and creating a `DataLoader` from them:
```julia
traindl = DataLoader(
    mapobs(sample -> encodesample(task, Training(), sample), shuffleobs(traindata)),
    4)
validdl = DataLoader(
    mapobs(sample -> encodesample(task, Validation(), sample), validdata),
    8)
```
By separating the data container preparation from the task-specific data preprocessing and augmentation, we were able to completely reuse the data container creation and quickly add FastAI.jl's data augmentation.
With the data iterators ready for training, we still need a model to train. For the Siamese similarity setting, a common model architecture is an encoder (feature extractor) that is applied to both images, and a head that transforms the concatenated image features into a similar/not similar categorical output.
Flux.jl makes it easy to create models in a similar fashion to PyTorch modules. The forward pass is implemented by overloading the call operator, and the backward pass is automatically generated.
```julia
using FastAI.Flux

struct SiameseModel{E,H}
    encoder::E
    head::H
end

# tells Flux that this struct contains submodules in its fields
Flux.@functor SiameseModel

# forward pass
function (m::SiameseModel)((xs1, xs2))
    return m.head(cat(m.encoder(xs1), m.encoder(xs2); dims = 3))
end
```
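As a quick sanity check of this definition, we can call a `SiameseModel` built from a toy encoder and head on dummy input. This is only an illustrative sketch; the toy layers below are not part of the tutorial's actual model.

```julia
m = SiameseModel(
    Conv((3, 3), 3 => 4),                                # toy encoder
    Chain(GlobalMeanPool(), Flux.flatten, Dense(8, 2)))  # toy head: 2 × 4 channels -> 2 classes
xs = (rand(Float32, 8, 8, 3, 1), rand(Float32, 8, 8, 3, 1))
size(m(xs))  # (2, 1): two logits for the single pair in the batch
```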
We'll use an XResNet model for the encoder (below we swap in a pretrained ResNet50 backbone from Metalhead.jl) and the same head that is used for classification models. Using `Flux.outputsize` we can check how many channels the encoder outputs without having to evaluate the model:
```julia
using Metalhead

encoder = Models.xresnet18()
encoder = Metalhead.ResNet50(pretrain = true).layers[1][1:end-1]
h, w, ch, b = Flux.outputsize(encoder, (128, 128, 3, 1))
```
(4, 4, 2048, 1)
We need to double this feature count since the head gets the concatenated features from two images:
```julia
head = Models.visionhead(2ch, 2)
```
Chain(Parallel(vcat, AdaptiveMeanPool((1, 1)), AdaptiveMaxPool((1, 1))), flatten, Chain(BatchNorm(8192), identity, Dense(8192, 512, relu; bias=false)), Chain(BatchNorm(512), identity, Dense(512, 2; bias=false)))
```julia
model = SiameseModel(encoder, head);
```
Let's test that the model works on a batch of training data:
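A minimal way to do this, assuming we take the first batch from `traindl` and move both the batch and the model to the GPU with Flux's `gpu` (a sketch, not the exact cell that produced the output below):

```julia
xs, ys = first(traindl)   # one encoded batch: ((image tensor 1, image tensor 2), one-hot labels)
model = gpu(model)
model(gpu(xs))
```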
2×32 CUDA.CuArray{Float32, 2}:
1.56911 2.57546 4.55453 2.17796 … 2.97762 2.73883 2.52148 3.14173
2.61323 3.30812 3.29609 4.31212 2.07954 2.65726 2.32586 3.6022
We'll use categorical cross-entropy (on logits) as the loss function and Adam as the optimizer:
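A minimal definition, assuming Flux's built-in `logitcrossentropy` and an `Adam` optimizer with default settings (matching the optimizer state shown below):

```julia
lossfn = Flux.Losses.logitcrossentropy
optimizer = Adam()
```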
Adam(0.001, (0.9, 0.999), IdDict{Any, Any}())
Now we can create a `Learner` and start training following the usual process.
```julia
callbacks = [ToGPU(), Metrics(accuracy)]

learner = Learner(model, (traindl, valdl), optimizer, lossfn, callbacks...)
```
Learner()
Before training, we recreate the encoder, head, and model, and construct a new `Learner`, this time passing an `Adam(0.01)` optimizer directly:

```julia
encoder = Metalhead.ResNet50(pretrain = true).layers[1][1:end-1]
h, w, ch, b = Flux.outputsize(encoder, (128, 128, 3, 1))
head = Models.visionhead(2ch, 2)
model = SiameseModel(encoder, head);

learner = Learner(model, (traindl, valdl), Adam(0.01), lossfn, callbacks...)
```
Learner()
```julia
fitonecycle!(learner, 10, 4e-4)
```
Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:33
Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:03
Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:34
Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:03
Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:34
Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:03
Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:34
Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:03
Epoch 5 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:34
Epoch 5 ValidationPhase(): 100%|████████████████████████| Time: 0:00:03
Epoch 6 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:34
Epoch 6 ValidationPhase(): 100%|████████████████████████| Time: 0:00:03
Epoch 7 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:34
Epoch 7 ValidationPhase(): 100%|████████████████████████| Time: 0:00:03
Epoch 8 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:34
Epoch 8 ValidationPhase(): 100%|████████████████████████| Time: 0:00:03
Epoch 9 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:33
Epoch 9 ValidationPhase(): 100%|████████████████████████| Time: 0:00:03
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 1.0 │ 1.05782 │ 0.49212 │
└───────────────┴───────┴─────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 1.0 │ 1.14344 │ 0.52018 │
└─────────────────┴───────┴─────────┴──────────┘
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 2.0 │ 1.18025 │ 0.51954 │
└───────────────┴───────┴─────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 2.0 │ 1.19179 │ 0.51063 │
└─────────────────┴───────┴─────────┴──────────┘
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 3.0 │ 0.85397 │ 0.56425 │
└───────────────┴───────┴─────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 3.0 │ 0.90825 │ 0.51324 │
└─────────────────┴───────┴─────────┴──────────┘
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 4.0 │ 0.70473 │ 0.59904 │
└───────────────┴───────┴─────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 4.0 │ 0.81478 │ 0.52344 │
└─────────────────┴───────┴─────────┴──────────┘
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 5.0 │ 0.65689 │ 0.63226 │
└───────────────┴───────┴─────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 5.0 │ 0.78461 │ 0.53559 │
└─────────────────┴───────┴─────────┴──────────┘
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 6.0 │ 0.61939 │ 0.65321 │
└───────────────┴───────┴─────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 6.0 │ 0.77865 │ 0.53754 │
└─────────────────┴───────┴─────────┴──────────┘
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 7.0 │ 0.60574 │ 0.66481 │
└───────────────┴───────┴─────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 7.0 │ 0.78052 │ 0.51302 │
└─────────────────┴───────┴─────────┴──────────┘
┌───────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │ 8.0 │ 0.58161 │ 0.69842 │
└───────────────┴───────┴─────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 8.0 │ 0.77444 │ 0.50738 │
└─────────────────┴───────┴─────────┴──────────┘
┌───────────────┬───────┬────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼────────┼──────────┤
│ TrainingPhase │ 9.0 │ 0.5571 │ 0.7161 │
└───────────────┴───────┴────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 9.0 │ 0.77094 │ 0.52148 │
└─────────────────┴───────┴─────────┴──────────┘
Epoch 10 TrainingPhase(): 100%|█████████████████████████| Time: 0:00:34
Epoch 10 ValidationPhase(): 100%|███████████████████████| Time: 0:00:03
┌───────────────┬───────┬────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├───────────────┼───────┼────────┼──────────┤
│ TrainingPhase │ 10.0 │ 0.5418 │ 0.73041 │
└───────────────┴───────┴────────┴──────────┘
┌─────────────────┬───────┬─────────┬──────────┐
│ Phase │ Epoch │ Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │ 10.0 │ 0.77138 │ 0.51172 │
└─────────────────┴───────┴─────────┴──────────┘