In the quickstart section, you saw a short example of how to train an image segmentation model. In this tutorial, we'll recreate that using the mid-level APIs.
In image segmentation, instead of assigning a class to a whole image as in image classification, we want to classify each pixel of an image. We'll use the CamVid dataset which contains street images with every pixel annotated as one of 32 classes.
Inside the dataset folder, we find `images/`, a subfolder with all the input images; `labels/`, a subfolder containing the segmentation masks (saved as images); and `codes.txt`, which contains the class names:
```julia
dir = load(datasets()["camvid"])
readdir(dir)
```

```
4-element Vector{String}:
 "codes.txt"
 "images"
 "labels"
 "valid.txt"
```
Every line in `codes.txt` corresponds to one class, so let's load it:
```julia
classes = readlines(open(joinpath(dir, "codes.txt")))
```

```
32-element Vector{String}:
 "Animal"
 "Archway"
 "Bicyclist"
 "Bridge"
 "Building"
 "Car"
 "CartLuggagePram"
 "Child"
 "Column_Pole"
 "Fence"
 "LaneMkgsDriv"
 "LaneMkgsNonDriv"
 "Misc_Text"
 ⋮
 "SignSymbol"
 "Sky"
 "SUVPickupTruck"
 "TrafficCone"
 "TrafficLight"
 "Train"
 "Tree"
 "Truck_Bus"
 "Tunnel"
 "VegetationMisc"
 "Void"
 "Wall"
```
To create our data containers, we'll use `loadfolderdata`, which creates data containers from folder contents. The function passed as `loadfn` is applied to each file path; we use `loadfile` to load the images and `loadmask` to load the masks.
```julia
images = Datasets.loadfolderdata(
    joinpath(dir, "images"),
    filterfn = FastVision.isimagefile,
    loadfn = loadfile)

masks = Datasets.loadfolderdata(
    joinpath(dir, "labels"),
    filterfn = FastVision.isimagefile,
    loadfn = f -> FastVision.loadmask(f, classes))

data = (images, masks)
```

```
(mapobs(loadfile, ObsView(::MLDatasets.FileDataset{typeof(identity), String}, ::Vector{Int64})), mapobs(#1, ObsView(::MLDatasets.FileDataset{typeof(identity), String}, ::Vector{Int64})))
```
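Since both containers list the same files in parallel folders, they should hold the same number of observations. A quick sanity check, sketched here using `numobs` from MLUtils:

```julia
using MLUtils: numobs

# The two containers must be aligned one-to-one for (image, mask) pairs.
@assert numobs(images) == numobs(masks)
numobs(data)  # number of (image, mask) pairs in the tuple container
```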
Now we can get an observation:
```julia
image, mask = sample = getobs(data, 1);
image
```
Each mask is an array with elements of the same type as our classes (here `String`), but efficiently stored using integers under the hood.
```julia
view(mask, 50:55, 50:55)
```

```
6×6 view(::IndirectArrays.IndirectArray{String, 2, UInt8, Base.ReinterpretArray{UInt8, 2, ColorTypes.Gray{FixedPointNumbers.N0f8}, Matrix{ColorTypes.Gray{FixedPointNumbers.N0f8}}, false}, Vector{String}}, 50:55, 50:55) with eltype String:
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
```
Next we need to create a learning task for image segmentation. This means using images to predict masks, so we'll use the `Image` and `Mask` blocks as input and target. Since the dataset is 2D, we'll use 2-dimensional blocks.
```julia
task = SupervisedTask(
    (Image{2}(), Mask{2}(classes)),
    (
        ProjectiveTransforms((128, 128)),
        ImagePreprocessing(),
        OneHot()
    )
)
```

```
SupervisedTask(Image{2} -> Mask{2, String})
```
The encodings passed in transform samples into formats suitable as inputs and outputs for a model, and also allow decoding model outputs to get back to our target format: an array of class labels for every pixel.
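For example, encoding a single sample should yield an image tensor and a one-hot mask tensor. A sketch, assuming `encodesample` and the shapes implied by the encodings above:

```julia
# Encode a raw (image, mask) sample into model-ready arrays.
x, y = encodesample(task, Training(), sample)
size(x), size(y)  # expected: ((128, 128, 3), (128, 128, 32))
```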
Let's check that samples from the created data container conform to the blocks of the learning task:
```julia
checkblock(task.blocks.sample, sample)
```

```
true
```
We can check that the encodings work as expected by creating a batch of encoded data and visualizing it. Here the segmentation masks are color-coded and overlaid on top of the image.
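A sketch of how that can look, assuming `FastAI.makebatch` and `showbatch` are available:

```julia
# Encode the first four observations into a batch and visualize it.
xs, ys = FastAI.makebatch(task, data, 1:4)
showbatch(task, (xs, ys))
```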
We can use `describetask` to get more information about learning tasks created through the data block API. We see which representations our data goes through and which encodings transform which parts.

```julia
describetask(task)
```
With a `task` and a matching data container, the only thing we need before we can create a `Learner` is a backbone architecture to build the segmentation model from. We'll use a slightly modified ResNet, but you can use any convolutional architecture.
We'll use `taskmodel` to construct the model from the backbone. Since we want mask outputs, the intermediate feature representation needs to be scaled back up. Based on the `Block`s we built our task from, `taskmodel` knows that it needs to build a mapping `ImageTensor{2} -> OneHotTensor{2}` and constructs a U-Net model.
```julia
backbone = Metalhead.ResNet(34).layers[1:end-1]
model = taskmodel(task, backbone);
```
In a similar vein, `tasklossfn` creates a loss function suitable for comparing model outputs and encoded targets.

```julia
lossfn = tasklossfn(task)
```

```
_segmentationloss (generic function with 1 method)
```
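We can sanity-check the whole pipeline on a single encoded sample before training. A sketch; the `encodesample` call and the `reshape` to add a batch dimension are assumptions, not part of the tutorial's code:

```julia
x, y = encodesample(task, Training(), sample)
xb = reshape(x, size(x)..., 1)  # add a trailing batch dimension
yb = reshape(y, size(y)..., 1)
lossfn(model(xb), yb)           # should return a scalar loss value
```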
Next we turn our data container into training and validation data loaders that will iterate over batches of encoded data, and construct a `Learner`.
```julia
traindl, validdl = taskdataloaders(data, task, 16)
optimizer = Adam()

learner = Learner(model, lossfn;
    data = (traindl, validdl),
    optimizer,
    callbacks = [ToGPU()])
```

```
Learner()
```
Note that we could also have used `tasklearner`, a shorthand that calls `taskdataloaders` and `taskmodel` for us.
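Roughly, the shorthand version would look like this (a sketch; the keyword names are assumptions based on the functions above and may differ from the actual signature):

```julia
# Equivalent-in-spirit one-liner: builds data loaders and model internally.
learner = tasklearner(task, data;
    backbone = backbone,
    batchsize = 16,
    callbacks = [ToGPU()])
```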
Now let's train the model:
```julia
fitonecycle!(learner, 10, 0.033)
```

```
Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:04:46
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   1.0 │ 2.82042 │
└───────────────┴───────┴─────────┘
Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:10
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   1.0 │ 2384.78 │
└─────────────────┴───────┴─────────┘
Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   2.0 │ 1.62247 │
└───────────────┴───────┴─────────┘
Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   2.0 │ 1.45235 │
└─────────────────┴───────┴─────────┘
Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   3.0 │ 1.24695 │
└───────────────┴───────┴─────────┘
Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬────────┐
│           Phase │ Epoch │   Loss │
├─────────────────┼───────┼────────┤
│ ValidationPhase │   3.0 │ 1.2517 │
└─────────────────┴───────┴────────┘
Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   4.0 │ 1.18153 │
└───────────────┴───────┴─────────┘
Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   4.0 │ 2.41649 │
└─────────────────┴───────┴─────────┘
Epoch 5 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   5.0 │ 1.15426 │
└───────────────┴───────┴─────────┘
Epoch 5 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   5.0 │ 1.23344 │
└─────────────────┴───────┴─────────┘
Epoch 6 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   6.0 │ 1.08371 │
└───────────────┴───────┴─────────┘
Epoch 6 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   6.0 │ 1.20521 │
└─────────────────┴───────┴─────────┘
Epoch 7 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:05
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   7.0 │ 1.07083 │
└───────────────┴───────┴─────────┘
Epoch 7 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   7.0 │ 1.13855 │
└─────────────────┴───────┴─────────┘
Epoch 8 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:05
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   8.0 │ 1.04082 │
└───────────────┴───────┴─────────┘
Epoch 8 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   8.0 │ 1.02529 │
└─────────────────┴───────┴─────────┘
Epoch 9 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   9.0 │ 1.01765 │
└───────────────┴───────┴─────────┘
Epoch 9 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   9.0 │ 0.94789 │
└─────────────────┴───────┴─────────┘
Epoch 10 TrainingPhase(): 100%|█████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │  10.0 │ 0.96513 │
└───────────────┴───────┴─────────┘
Epoch 10 ValidationPhase(): 100%|███████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │  10.0 │ 0.90619 │
└─────────────────┴───────┴─────────┘
```
And look at the results on a batch of validation data:

```julia
showoutputs(task, learner; n = 4)
```
The data block API can also be used for 3D segmentation, but the tutorial for that is still in the works.
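The same pattern would carry over by swapping in 3-dimensional blocks. A hypothetical, untested sketch, assuming the encodings support three spatial dimensions and using a made-up target size:

```julia
# Hypothetical 3D segmentation task; block and encoding names as above.
task3d = SupervisedTask(
    (Image{3}(), Mask{3}(classes)),
    (ProjectiveTransforms((64, 64, 64)), ImagePreprocessing(), OneHot())
)
```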