predict.jl

FastAI/tasks/predict.jl is a source file in module FastAI

			
			
			
			
			
			"""
			

			    predict(task, model, input[; device, context])

			

			Predict a `target` from `input` using `model`. Optionally apply function `device`

			to `x` before passing to `model` and use `context` instead of the default

			context [`Inference`](#).

			"""
			

			
			function
			 
			

	
			predict
			(
			task
			,
			 
			model
			,
			 
			input
			
			;
			 
			
			device
			 
			=
			 

	
			cpu
			,
			 
			
			undevice
			 
			=
			 

	
			cpu
			,
			 
			
			context
			 
			=
			 
			

	
			Inference
			(
			)
			)
			
			
    
			
			if
			 
			

	
			shouldbatch
			(
			task
			)
			
			
        
			
			return
			
			 
			

	
			predictbatch
			(
			task
			,
			
                            
			model
			,
			
                            
			
			[
			input
			]
			
			;
			
                            
			
			device
			 
			=
			 
			device
			,
			
                            
			
			undevice
			 
			=
			 
			undevice
			,
			
                            
			
			context
			 
			=
			 
			context
			)
			 
			|>
			 
			only
			
    
			else
			
			
        
			
			return
			 
			

	
			decodeypred
			(
			task
			,
			
                           
			context
			,
			
                           
			
			undevice
			(
			
			model
			(
			
			device
			(
			

	
			encodeinput
			(
			task
			,
			 
			context
			,
			 
			input
			)
			)
			)
			)
			)
			
    
			end
			

			end
			

			

			
			
			
			"""
			

			    predictbatch(task, model, inputs[; device, context])

			

			Predict `targets` from a vector of `inputs` using `model` by batching them.

			Optionally apply function `device` to batch before passing to `model` and

			use `context` instead of the default [`Inference`](#).

			"""
			

			
			function
			 
			

	
			predictbatch
			(
			task
			,
			
                      
			model
			,
			
                      
			inputs
			
			;
			
                      
			
			device
			 
			=
			 

	
			cpu
			,
			
                      
			
			undevice
			 
			=
			 

	
			cpu
			,
			
                      
			
			context
			 
			=
			 
			

	
			Inference
			(
			)
			)
			
			
    
			
			xs
			 
			=
			 
			
			device
			(
			
			

	
			MLUtils
			.
			
			batch
			(
			
			[
			
			
			copy
			(
			

	
			encodeinput
			(
			task
			,
			 
			context
			,
			 
			input
			)
			)
			 
			for
			
			 
			input
			 
			in
			 
			inputs
			]
			)
			)
			
    
			
			ŷs
			 
			=
			 
			
			undevice
			(
			
			model
			(
			xs
			)
			)
			
    
			
			targets
			 
			=
			 
			
			[
			
			

	
			decodeypred
			(
			task
			,
			 
			context
			,
			 
			)
			 
			for
			
			 
			 
			in
			 
			
			

	
			Datasets
			.
			

	
			unbatch
			(
			ŷs
			)
			]
			
    
			
			return
			 
			targets
			

			end