learner.jl

YVIq2/src/learner.jl is a source file in module YVIq2

			
			
			
			# Mutable bundle of a callback collection together with the runner that
			# dispatches events to it.
			mutable struct Callbacks
			    # the callbacks themselves
			    cbs::Vector
			    # runner that `handle` forwards events to
			    runner::CallbackRunner
			    # graph stored alongside the callbacks
			    graph::SimpleDiGraph
			    # initialisation flag
			    initialized::Bool
			end
			
			

			
			
			
	
		
			# Convenience constructor: builds the graph from `cbs` with `callbackgraph`
			# and starts with `initialized == false`. `runner` defaults to a fresh
			# `LinearRunner()`.
			function Callbacks(cbs, runner=LinearRunner())
			    graph = callbackgraph(cbs)
			    return Callbacks(cbs, runner, graph, false)
			end
			
			

			
			

			
			
			
	
		
			# Call `init!(cb, learner)` for every callback held in `cbs.cbs`, in order.
			# Returns `nothing`.
			function init!(cbs::Callbacks, learner)
			    for cb in cbs.cbs
			        init!(cb, learner)
			    end
			    return nothing
			end
			
			

			
			

			
			

			
			
			# Mutable container for everything a training run needs: model, data
			# iterators, optimiser, loss function, optimiser state and callbacks.
			mutable struct Learner
			    # the model; replaced through `model!`
			    model
			    # data iterators, normalised to a `PropDict` by `_dataiters`
			    data::PropDict
			    optimizer
			    lossfn
			    # Historically held `Flux.Params`; when an optimiser from Optimisers.jl
			    # is used it holds the optimiser state instead (see `setupoptimstate`).
			    params
			    # created as an empty `PropDict()` by the constructor
			    step::PropDict
			    callbacks::Callbacks
			    # created as an empty `PropDict()` by the constructor
			    cbstate::PropDict
			end
			
			

			
			

			
			

			
			
			# Keyword-style entry point. Forwards to the positional constructor,
			# splatting `callbacks` into trailing positional arguments and passing all
			# remaining keyword arguments through unchanged.
			Learner(model, lossfn; callbacks = [], data = (), optimizer = ADAM(), kwargs...) =
			    Learner(model, data, optimizer, lossfn, callbacks...; kwargs...)
			
			

			
			

			
			

			
			
			# Positional constructor.
			#
			# Arguments
			# - `model`, `optimizer`, `lossfn`: stored on the learner as given.
			# - `data`: anything accepted by `_dataiters` (a `PropDict`, a `NamedTuple`,
			#   or a `Tuple` of up to two iterators).
			# - `callbacks...`: zero or more `Callback`s.
			#
			# Keyword arguments
			# - `usedefaultcallbacks`: when `true`, each callback from
			#   `defaultcallbacks()` is appended unless a callback of exactly the same
			#   type is already present.
			# - `cbrunner`: runner passed to `Callbacks`.
			#
			# Returns the new `Learner` after `init!` has been called on its callbacks.
			function Learner(
			        model, data, optimizer, lossfn, callbacks::Vararg{Callback};
			        usedefaultcallbacks=true, cbrunner=LinearRunner()
			    )
			    # `Vararg{Callback}` rather than `Vararg{<:Callback}`: wrapping `Vararg`
			    # directly in a `UnionAll` is deprecated since Julia 1.7 and accepts
			    # the same arguments.
			    cbvec = collect(Callback, callbacks)

			    if usedefaultcallbacks
			        for cb in defaultcallbacks()
			            # Compare by concrete type so a user-supplied callback
			            # overrides the default of the same type.
			            if !any(existing -> typeof(existing) == typeof(cb), cbvec)
			                push!(cbvec, cb)
			            end
			        end
			    end

			    cbs = Callbacks(cbvec, cbrunner)

			    learner = Learner(
			        model,
			        _dataiters(data),
			        optimizer,
			        lossfn,
			        setupoptimstate(model, optimizer),
			        PropDict(),
			        cbs,
			        PropDict())
			    init!(cbs, learner)
			    return learner
			end
			
			

			
			

			
			

			
			
			
			
			# Compact display: always prints the fixed text `Learner()`.
			function Base.show(io::IO, learner::Learner)
			    return print(io, "Learner()")
			end
			
			

			
			

			
			

			
			
			
			
	
		
			# Fresh instances of the callbacks added to a `Learner` when
			# `usedefaultcallbacks` is `true`.
			function defaultcallbacks()::Vector{AbstractCallback}
			    defaults = [
			        ProgressPrinter(),
			        MetricsPrinter(),
			        StopOnNaNLoss(),
			        Recorder(),
			        Metrics(),
			    ]
			    return defaults
			end

Callback handling


			
			
			
			
	
		
			# Dispatch `event` to the learner's callbacks by delegating to the
			# runner-specific `handle` method.
			function handle(event, learner, phase)
			    runner = learner.callbacks.runner
			    return handle(runner, event, phase, learner)
			end

Other


			
			
			
			
	
		
			# Training phases map to the `:training` key.
			function phasedataiter(::AbstractTrainingPhase)
			    return :training
			end
			
			

			
			
			
	
		
			# Validation phases map to the `:validation` key.
			function phasedataiter(::AbstractValidationPhase)
			    return :validation
			end
			
			

			
			

			
			

			
			
			# Replace the learner's model and rebuild `learner.params` for it using the
			# learner's current optimizer. Returns the new optimiser state.
			function model!(learner, model)
			    learner.model = model
			    state = setupoptimstate(model, learner.optimizer)
			    learner.params = state
			    return state
			end

Flux.jl optimisers store the model's params, while Optimisers.jl optimisers store the result of `setup`.


			
			
			
			
	
		
			# Flux.jl optimisers: the stored state is `Flux.params(model)`.
			function setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser)
			    return Flux.params(model)
			end

Optimisers.jl has no abstract supertype so we assume non-Flux optimisers conform to the Optimisers.jl interface.


			
			
			
			
	
		
			# Fallback for any other optimiser: the stored state is whatever
			# `Optimisers.setup(optim, model)` returns.
			function setupoptimstate(model, optim)
			    return Optimisers.setup(optim, model)
			end
			
			

			
			

			
			

			
			
			
	
		
			# A `PropDict` is already in the target form; return it as is.
			function _dataiters(d::PropDict)
			    return d
			end
			
			

			
			
			
	
		
			# Convert a `NamedTuple` into a `PropDict` built from its key/value pairs.
			function _dataiters(t::NamedTuple)
			    entries = pairs(t)
			    return PropDict(entries)
			end
			
			

			
			
			# Normalise a positional `Tuple` of data iterators:
			# - `()`       -> empty `PropDict`
			# - `(a,)`     -> `training = a`
			# - `(a, b)`   -> `training = a`, `validation = b`
			# Longer tuples raise an error asking for a `NamedTuple` or `PropDict`.
			function _dataiters(t::Tuple)
			    if length(t) == 0
			        return PropDict(Dict{Symbol,Any}())
			    elseif length(t) == 1
			        # The trailing comma is required: `(training = t[1])` is a
			        # parenthesised assignment that evaluates to `t[1]`, not a
			        # one-field NamedTuple.
			        return _dataiters((training = t[1],))
			    elseif length(t) == 2
			        return _dataiters((training = t[1], validation = t[2]))
			    else
			        error("Please pass a `NamedTuple` or `PropDict` as `data`.")
			    end
			end