Image segmentation

In the quickstart section, you saw a short example of how to train an image segmentation model. In this tutorial, we'll recreate that using the mid-level APIs.


			
			
			
```julia
using FastAI, FastVision, Flux, Metalhead
import CairoMakie; CairoMakie.activate!(type="png")
```

In image segmentation, instead of assigning a class to a whole image as in image classification, we want to classify each pixel of an image. We'll use the CamVid dataset, which contains street images with every pixel annotated as one of 32 classes.

Inside the dataset folder we find images/, a subfolder with all the input images; labels/, a subfolder containing the segmentation masks (saved as images); and codes.txt, which contains the class names (there is also a valid.txt file, which we won't need here):


			
			
			
```julia
dir = load(datasets()["camvid"])
readdir(dir)
```

4-element Vector{String}:
 "codes.txt"
 "images"
 "labels"
 "valid.txt"

Every line in codes.txt corresponds to one class, so let's load it:


			
			
			
```julia
# readlines with a file path opens the file, reads every line and closes it again
classes = readlines(joinpath(dir, "codes.txt"))
```

32-element Vector{String}:
 "Animal"
 "Archway"
 "Bicyclist"
 "Bridge"
 "Building"
 "Car"
 "CartLuggagePram"
 "Child"
 "Column_Pole"
 "Fence"
 "LaneMkgsDriv"
 "LaneMkgsNonDriv"
 "Misc_Text"
 ⋮
 "SignSymbol"
 "Sky"
 "SUVPickupTruck"
 "TrafficCone"
 "TrafficLight"
 "Train"
 "Tree"
 "Truck_Bus"
 "Tunnel"
 "VegetationMisc"
 "Void"
 "Wall"
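Because each line of codes.txt names exactly one class, a class's position in that list can serve as its integer ID. The plain-Julia sketch below illustrates this. The three-line file is a made-up stand-in for codes.txt, and the "1-based line position = class ID" convention is our assumption for illustration, not a statement about how FastVision encodes masks internally.

```julia
# Made-up stand-in for codes.txt: one class name per line.
path = tempname()
write(path, "Animal\nArchway\nBicyclist\n")

# readlines(path) opens the file, returns one String per line (newlines stripped)
# and closes it again.
classes = readlines(path)

# Assumption for illustration: a class's 1-based line position is its integer ID.
classindex = Dict(name => i for (i, name) in enumerate(classes))
println(classindex["Bicyclist"])
```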

To create our data containers we'll use loadfolderdata, which creates data containers from folder contents. Only file paths accepted by filterfn are kept, and the function passed as loadfn is applied to a file path when its observation is loaded: we use loadfile to load the images and loadmask to load the masks.


			
			
			
```julia
# Input images: keep only image files and load each one with loadfile
images = Datasets.loadfolderdata(
    joinpath(dir, "images"),
    filterfn=FastVision.isimagefile,
    loadfn=loadfile)

# Segmentation masks: load each file as a mask over our class names
masks = Datasets.loadfolderdata(
    joinpath(dir, "labels"),
    filterfn=FastVision.isimagefile,
    loadfn=f -> FastVision.loadmask(f, classes))

data = (images, masks)
```

(mapobs(loadfile, ObsView(::MLDatasets.FileDataset{typeof(identity), String}, ::Vector{Int64})), mapobs(#1, ObsView(::MLDatasets.FileDataset{typeof(identity), String}, ::Vector{Int64})))
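As the mapobs in the output suggests, each container stores file paths and only runs loadfn when an observation is requested. Indexing images and masks with the same index is what yields a matching (image, mask) pair. The sketch below shows this filter-then-load-lazily idea in plain Julia. FolderData, folderdata and isimage are hypothetical names, not FastAI.jl's implementation, and the extension check is only a stand-in for FastVision.isimagefile.

```julia
# A lazy "folder data container": stores filtered paths, loads on access.
struct FolderData{F}
    paths::Vector{String}   # file paths that passed the filter, in sorted order
    loadfn::F               # applied to a path only when an observation is requested
end

function folderdata(dir; filterfn = _ -> true, loadfn = identity)
    paths = filter(filterfn, sort(readdir(dir; join = true)))
    return FolderData(paths, loadfn)
end

Base.length(d::FolderData) = length(d.paths)
Base.getindex(d::FolderData, i::Integer) = d.loadfn(d.paths[i])

# Hypothetical stand-in for FastVision.isimagefile: decide by file extension.
isimage(path) = lowercase(splitext(path)[2]) in (".png", ".jpg", ".jpeg")

# Usage: a temporary folder with two "images" and one text file.
dir = mktempdir()
for name in ("b.png", "a.jpg", "notes.txt")
    touch(joinpath(dir, name))
end
container = folderdata(dir; filterfn = isimage, loadfn = basename)
println(length(container), " ", container[1])   # the .txt file was filtered out
```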

Now we can get an observation:


			
			
			
			
			
```julia
image, mask = sample = getobs(data, 1);
image
```

Each mask is an array whose elements have the same type as our classes (here String), but it is stored efficiently as integer indices under the hood.


			
			
			
```julia
view(mask, 50:55, 50:55)
```

6×6 view(::IndirectArrays.IndirectArray{String, 2, UInt8, Base.ReinterpretArray{UInt8, 2, ColorTypes.Gray{FixedPointNumbers.N0f8}, Matrix{ColorTypes.Gray{FixedPointNumbers.N0f8}}, false}, Vector{String}}, 50:55, 50:55) with eltype String:
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
 "Building"  "Building"  "Building"  "Building"  "Building"  "Building"
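The IndirectArray in the output above pairs a matrix of UInt8 indices (one byte per pixel) with the Vector{String} of class names. The same idea fits in a few lines of plain Julia. This is a sketch with made-up values, not the IndirectArrays.jl implementation.

```julia
# A compact mask: one UInt8 class index per pixel plus a lookup table of names.
classes = ["Building", "Car", "Sky"]
indices = UInt8[3 3 3; 1 2 1]      # 2×3 mask of class indices

# Indexing the lookup table with the index matrix decodes every pixel at once,
# giving a 2×3 Matrix{String}.
labels = classes[indices]

println(sizeof(indices), " bytes of indices")   # one byte per pixel
```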

Next we need to create a learning task for image segmentation. This means using images to predict masks, so we'll use the Image and Mask blocks as input and target. Since the dataset is 2D, we'll use 2-dimensional blocks.


			
			
			
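```julia
task = SupervisedTask(
    (Image{2}(), Mask{2}(classes)),       # input and target blocks
    (
        ProjectiveTransforms((128, 128)),  # geometric transforms to 128×128, applied to image and mask alike
        ImagePreprocessing(),              # image -> normalized array
        OneHot()                           # mask -> one-hot encoded array
    )
)
```

SupervisedTask(Image{2} -> Mask{2, String})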

The encodings we passed in transform samples into formats suitable as model inputs and outputs, and they also allow decoding model outputs back into our target format: an array of class labels for every pixel.
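To make the encode/decode round trip concrete, here is a plain-Julia sketch of one-hot encoding a small mask and decoding it again. The H×W×K layout (one channel per class) and the helper names onehot_mask and decode_mask are our own choices for illustration, not necessarily what OneHot() produces. Decoding takes the highest-scoring class at each pixel, so the same step works on model scores, not just on exact one-hot arrays.

```julia
# Encode a 2D mask of class names as an H×W×K array with one channel per class.
function onehot_mask(mask::AbstractMatrix, classes::AbstractVector)
    h, w = size(mask)
    y = zeros(Float32, h, w, length(classes))
    for j in 1:w, i in 1:h
        k = findfirst(==(mask[i, j]), classes)
        k === nothing && error("unknown class $(mask[i, j])")
        y[i, j, k] = 1f0                      # mark the pixel's class channel
    end
    return y
end

# Decode by picking the highest-scoring class for every pixel.
function decode_mask(y::AbstractArray{<:Real,3}, classes::AbstractVector)
    h, w, _ = size(y)
    return [classes[argmax(view(y, i, j, :))] for i in 1:h, j in 1:w]
end

classes = ["Building", "Car", "Sky"]
mask = ["Sky" "Sky"; "Building" "Car"]
y = onehot_mask(mask, classes)
decoded = decode_mask(y, classes)
println(size(y), " ", decoded == mask)
```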

Let's check that samples from the created data container conform to the blocks of the learning task:


			
			
			

	
```julia
checkblock(task.blocks.sample, sample)
```

true

We can check that the encodings work as expected by creating a batch of encoded data and visualizing it. Here the segmentation masks are color-coded and overlaid on top of the images.


			
			
			
			
```julia
# Encode observations 1 to 3 and collate them into a batch
xs, ys = FastAI.makebatch(task, data, 1:3)
showbatch(task, (xs, ys))
```
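Assuming the encoded samples are same-sized arrays, batching them essentially means stacking them along an extra trailing dimension, the N in Flux's W×H×C×N image convention. A minimal sketch follows. batchify is a hypothetical helper, not a FastAI.jl function, and makebatch may differ in details.

```julia
# Stack N-dimensional samples of equal size along a new trailing batch dimension.
function batchify(xs::AbstractVector{<:AbstractArray{T,N}}) where {T,N}
    return cat(xs...; dims = N + 1)
end

# Three fake 128×128×3 encoded images become one 128×128×3×3 batch.
samples = [rand(Float32, 128, 128, 3) for _ in 1:3]
xs = batchify(samples)
println(size(xs))
```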

We can use describetask to get more information about learning tasks created through the data block API. We see which representations our data goes through and which encodings transform which parts.


			
			
			

	
```julia
describetask(task)
```

With a task and a matching data container, the only thing we need before we can create a Learner is a backbone architecture to build the segmentation model from. We'll use a ResNet-34 with its final classification head removed, but you can use any convolutional architecture.

We'll use taskmodel to construct the model from the backbone. Since we want mask outputs, the intermediate feature representation needs to be scaled back up. Based on the Blocks we built our task from, taskmodel knows that it needs to build a mapping ImageTensor{2} -> OneHotTensor{2} and constructs a U-Net model.


			
			
			
```julia
# ResNet-34 without its last layer (the classification head), used as the encoder
backbone = Metalhead.ResNet(34).layers[1:end-1]
model = taskmodel(task, backbone);
```
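The feature maps must be scaled back up because a standard ResNet backbone halves the spatial resolution five times: a stride-2 stem convolution, a max-pool, and three strided stages. Our 128×128 crops therefore shrink to 4×4 feature maps, and a U-Net decoder has to undo those halvings to produce a full-size mask. The sketch below only tracks the sizes; it does not build the layers that taskmodel creates.

```julia
# Spatial size after each of the five stride-2 downsampling steps,
# starting from the 128×128 size set by ProjectiveTransforms.
insize = 128
down = accumulate((s, _) -> s ÷ 2, 1:5; init = insize)

# A U-Net-style decoder doubles the resolution step by step back to the input size.
up = accumulate((s, _) -> 2s, 1:5; init = last(down))
println(down, " => ", up)
```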

In a similar vein, tasklossfn creates a loss function suitable for comparing model outputs and encoded targets.


			
			
			
```julia
lossfn = tasklossfn(task)
```

_segmentationloss (generic function with 1 method)
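The printed name _segmentationloss does not reveal the formula. A common choice for segmentation, and our assumption here, is pixel-wise cross-entropy: apply a softmax over the class dimension at every pixel and average the negative log-probability of the true class. Below is a plain-Julia sketch for a single H×W×K sample. pixel_crossentropy is a hypothetical name, and this is not FastAI.jl's implementation.

```julia
# Pixel-wise cross-entropy on logits (H×W×K), averaged over the H×W pixels.
function pixel_crossentropy(logits::AbstractArray{<:Real,3}, onehot::AbstractArray{<:Real,3})
    # Numerically stable log-softmax over the class dimension (dims = 3)
    m = maximum(logits; dims = 3)
    lse = m .+ log.(sum(exp.(logits .- m); dims = 3))
    logp = logits .- lse
    npixels = size(logits, 1) * size(logits, 2)
    return -sum(onehot .* logp) / npixels
end

uniform = zeros(Float32, 2, 2, 4)     # equal scores for all 4 classes
target = zeros(Float32, 2, 2, 4)
target[:, :, 1] .= 1                  # every pixel belongs to class 1
confident = 10 .* target              # high score on the correct class

println(pixel_crossentropy(uniform, target))     # approximately log(4)
println(pixel_crossentropy(confident, target))   # close to zero
```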

Next, we turn our data container into training and validation data loaders that iterate over batches of encoded data, and then construct a Learner.


			
			
			
			
```julia
# Training and validation data loaders with a batch size of 16
traindl, validdl = taskdataloaders(data, task, 16)
optimizer = Adam()
learner = Learner(model, lossfn;
                  data=(traindl, validdl), optimizer, callbacks=[ToGPU()])
```

Learner()
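Behind taskdataloaders(data, task, 16) there is some index bookkeeping: split the observations into training and validation sets, then group them into batches of 16. The sketch below shows only that bookkeeping. split_and_batch is a hypothetical helper, and the 20% validation fraction is our assumption rather than a documented default.

```julia
using Random

# Shuffle observation indices, hold out a validation fraction, and group the
# remaining training indices into mini-batches of `batchsize`.
function split_and_batch(nobs::Integer, batchsize::Integer;
                         pctval = 0.2, rng = Random.default_rng())
    idxs = randperm(rng, nobs)
    nval = round(Int, pctval * nobs)
    validx, trainidx = idxs[1:nval], idxs[nval+1:end]
    trainbatches = [collect(b) for b in Iterators.partition(trainidx, batchsize)]
    return trainbatches, validx
end

trainbatches, validx = split_and_batch(100, 16; rng = MersenneTwister(1))
println(length.(trainbatches), " ", length(validx))
```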

Note that we could also have used tasklearner, which is a shorthand that calls taskdataloaders and taskmodel for us.

Now let's train the model:


			
			
			

	
```julia
fitonecycle!(learner, 10, 0.033)
```

			Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:04:46
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   1.0 │ 2.82042 │
└───────────────┴───────┴─────────┘
Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:10
┌─────────────────┬───────┬─────────┐
│           Phase  Epoch     Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   1.0 │ 2384.78 │
└─────────────────┴───────┴─────────┘
Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   2.0 │ 1.62247 │
└───────────────┴───────┴─────────┘
Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase  Epoch     Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   2.0 │ 1.45235 │
└─────────────────┴───────┴─────────┘
Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   3.0 │ 1.24695 │
└───────────────┴───────┴─────────┘
Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬────────┐
│           Phase  Epoch    Loss │
├─────────────────┼───────┼────────┤
│ ValidationPhase │   3.0 │ 1.2517 │
└─────────────────┴───────┴────────┘
Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   4.0 │ 1.18153 │
└───────────────┴───────┴─────────┘
Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase  Epoch     Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   4.0 │ 2.41649 │
└─────────────────┴───────┴─────────┘
Epoch 5 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   5.0 │ 1.15426 │
└───────────────┴───────┴─────────┘
Epoch 5 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase  Epoch     Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   5.0 │ 1.23344 │
└─────────────────┴───────┴─────────┘
Epoch 6 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   6.0 │ 1.08371 │
└───────────────┴───────┴─────────┘
Epoch 6 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase  Epoch     Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   6.0 │ 1.20521 │
└─────────────────┴───────┴─────────┘
Epoch 7 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:05
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   7.0 │ 1.07083 │
└───────────────┴───────┴─────────┘
Epoch 7 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase  Epoch     Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   7.0 │ 1.13855 │
└─────────────────┴───────┴─────────┘
Epoch 8 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:05
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   8.0 │ 1.04082 │
└───────────────┴───────┴─────────┘
Epoch 8 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase  Epoch     Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   8.0 │ 1.02529 │
└─────────────────┴───────┴─────────┘
Epoch 9 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   9.0 │ 1.01765 │
└───────────────┴───────┴─────────┘
Epoch 9 ValidationPhase(): 100%|████████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase  Epoch     Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │   9.0 │ 0.94789 │
└─────────────────┴───────┴─────────┘
Epoch 10 TrainingPhase(): 100%|█████████████████████████| Time: 0:00:06
┌───────────────┬───────┬─────────┐
│         Phase  Epoch     Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │  10.0 │ 0.96513 │
└───────────────┴───────┴─────────┘
Epoch 10 ValidationPhase(): 100%|███████████████████████| Time: 0:00:01
┌─────────────────┬───────┬─────────┐
│           Phase  Epoch     Loss │
├─────────────────┼───────┼─────────┤
│ ValidationPhase │  10.0 │ 0.90619 │
└─────────────────┴───────┴─────────┘
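The call fitonecycle!(learner, 10, 0.033) trains for 10 epochs under a one-cycle learning-rate schedule whose peak rate is 0.033. The rate first warms up toward that peak and then anneals to a much smaller value. The sketch below shows the general shape of such a schedule. The warm-up fraction (pctstart) and the start/end divisors (startdiv, enddiv) are assumptions chosen for illustration, not FastAI.jl's defaults.

```julia
# One-cycle schedule: cosine warm-up from maxlr/startdiv to maxlr, then cosine
# annealing down to maxlr/enddiv.
function onecycle_lr(step, nsteps, maxlr; pctstart = 0.25, startdiv = 25, enddiv = 1e5)
    startlr, endlr = maxlr / startdiv, maxlr / enddiv
    warmup = max(1, round(Int, pctstart * nsteps))
    # Cosine interpolation that returns a at t = 0 and b at t = 1.
    cosinterp(a, b, t) = b + (a - b) * (1 + cos(pi * t)) / 2
    if step <= warmup
        return cosinterp(startlr, maxlr, step / warmup)
    else
        return cosinterp(maxlr, endlr, (step - warmup) / (nsteps - warmup))
    end
end

nsteps = 100
lrs = [onecycle_lr(s, nsteps, 0.033) for s in 0:nsteps]
println(maximum(lrs))
```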

Finally, let's look at the results on a batch of validation data:


			
			
			

	
```julia
showoutputs(task, learner; n = 4)
```