Saving and loading models for inference

Ultimately, we train models so that we can use them for inference, that is, to generate predictions for new inputs. The general recipe in FastAI.jl is to first train a model for a task, for example with fitonecycle! or finetune!, and then save the model together with the learning task configuration to a file using savetaskmodel. In another session, you can then use loadtaskmodel to load both. Since the learning task contains all the preprocessing logic, we can then use predict and predictbatch to generate predictions for new inputs.

Let's fine-tune an image classification model and walk through that process.


			
			
			
    using FastAI, FastVision, Metalhead

    dir = joinpath(load(datasets()["dogscats"]), "train")
    data = loadfolderdata(dir, filterfn=FastVision.isimagefile, loadfn=(loadfile, parentname))
    classes = unique(eachobs(data[2]))
    task = BlockTask(
        (Image{2}(), Label(classes)),
        (
            ProjectiveTransforms((196, 196)),
            ImagePreprocessing(),
            OneHot()
        )
    )

    backbone = Metalhead.ResNet50().layers[1:end-1]
    learner = tasklearner(task, data; backbone=backbone, callbacks=[ToGPU(), Metrics(accuracy)])
    fitonecycle!(learner, 10, 0.01)

Now we can save the model and the learning task using savetaskmodel. Passing force = true overwrites the file if it already exists.


			
			
			

	
    savetaskmodel("catsdogs.jld2", task, learner.model, force = true)

In another session we can now use loadtaskmodel to load both model and learning task from the file. Since the model weights are transferred to the CPU before being saved, we need to move them to the GPU manually if we want to use that for inference.


			
			
			
			
    task, model = loadtaskmodel("catsdogs.jld2")
    model = gpu(model);

Finally, let's select 9 random images from the dataset and see if the model classifies them correctly:


			
			
    # use it for inference
    x, y = samples = [getobs(data, i) for i in rand(1:numobs(data), 9)]
    images = [sample[1] for sample in samples]
    labels = [sample[2] for sample in samples]
    preds = predictbatch(task, model, images; device = gpu, context = Validation())

			9-element Vector{String}:
 "dogs"
 "dogs"
 "dogs"
 "dogs"
 "cats"
 "dogs"
 "dogs"
 "dogs"
 "cats"

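Note that rand(1:numobs(data), 9) draws indices with replacement, so the same image can occasionally appear more than once among the 9 samples. If distinct images are wanted, one option is to shuffle the index range and take the first 9 entries. The following is a minimal sketch using only the Random standard library; n is a hypothetical placeholder standing in for numobs(data), and idxs is a name introduced here for illustration.

```julia
using Random

# Hypothetical dataset size; in the tutorial this would be numobs(data).
n = 1000

# randperm(n) is a random permutation of 1:n, so its first 9 entries
# are 9 distinct indices drawn without replacement.
idxs = randperm(n)[1:9]
```

The resulting indices could then be used in place of the rand call, for example as [getobs(data, i) for i in idxs].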
			
			
			
    acc = sum(labels .== preds) / length(preds)

			1.0
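The accuracy above is simply the fraction of predictions that match their labels. As a small worked example, here is the same computation on made-up values (the toy_ names and their contents are invented for illustration and are not results from the model), written with mean from the Statistics standard library as an equivalent alternative to dividing the sum by the length.

```julia
using Statistics

# Made-up labels and predictions, purely for illustration.
toy_labels = ["dogs", "cats", "dogs", "cats"]
toy_preds  = ["dogs", "cats", "cats", "cats"]

# The element-wise comparison yields a Bool vector;
# its mean is the fraction of positions where both agree (3 of 4 here).
toy_acc = mean(toy_labels .== toy_preds)
```

With only 9 samples, as in the tutorial, such an estimate is coarse: each misclassified image changes the accuracy by about 0.11.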

			
			
			
    import FastMakie, CairoMakie
    FastAI.showsamples(task, collect(zip(images, preds)))