training.jl

`YVIq2/src/training.jl` is a source file in module `YVIq2`.

			
			
			
			"""
			    epoch!(learner, phase::Phase[, dataiter])

			Run one epoch of `phase` on `learner`: `step!` is called once for every
			batch produced by `dataiter`, and the whole loop runs inside `runepoch`
			so that `EpochBegin`/`EpochEnd` events are sent to the callbacks.

			`dataiter` defaults to `learner.data[phasedataiter(phase)]`.
			"""
			function epoch!(learner, phase::Phase, dataiter = learner.data[phasedataiter(phase)])
				runepoch(learner, phase) do _
					# The epoch-level handler supplied by `runepoch` is not needed here;
					# `step!` goes through `runstep`, which fires its own step events.
					foreach(b -> step!(learner, phase, b), dataiter)
				end
			end
			
			

			
			

			
			

			
			
			# `step!(learner, phase, batch)`: run a single step of `phase` on one `batch`.
			# Declared here as a generic function with no methods; this file defines the
			# methods for `TrainingPhase` and `ValidationPhase`.
			function step! end
			
			

			
			

			
			

			
			
			"""
			    step!(learner, phase::TrainingPhase, batch)

			Run one training step on `batch`, which is destructured into `(xs, ys)`.

			Inside `runstep`, the model output `state.ŷs` and the loss
			`state.loss = learner.lossfn(state.ŷs, state.ys)` are computed in the closure
			passed to `_gradient`, whose result is stored in `state.grads`. The gradients
			are then applied with `_update!`, and the returned parameters (or optimiser
			state) and model are written back to `learner.params` and `learner.model`.

			Besides the step events fired by `runstep`, this sends `LossBegin`,
			`BackwardBegin` and `BackwardEnd` to the callbacks. Returns the step state
			produced by `runstep`.
			"""
			function step!(learner, phase::TrainingPhase, batch)
				xs, ys = batch
				initialstate = (xs = xs, ys = ys)
				runstep(learner, phase, initialstate) do handlefn, state
					# `LossBegin` and `BackwardBegin` are fired from inside the closure that
					# `_gradient` differentiates, between the forward pass and returning the loss.
					state.grads = _gradient(learner.optimizer, learner.model, learner.params) do m
						state.ŷs = m(state.xs)
						handlefn(LossBegin())
						state.loss = learner.lossfn(state.ŷs, state.ys)
						handlefn(BackwardBegin())
						return state.loss
					end
					handlefn(BackwardEnd())
					# `_update!` returns the (possibly new) params/optimiser state and model.
					newparams, newmodel = _update!(learner.optimizer, learner.params, learner.model, state.grads)
					learner.params = newparams
					learner.model = newmodel
				end
			end

The helpers below handle both old-style Flux.jl optimisers (implicit `Params`) and new-style Optimisers.jl optimisers (explicit optimiser state).


			
			
			
			
	
		
			# Fallback for any optimiser that is not a `Flux.Optimise.AbstractOptimiser`
			# (presumably an Optimisers.jl rule; the optimiser and params arguments are
			# ignored). Differentiate `f` explicitly with respect to the model `m` and keep
			# the first entry of the result, i.e. the gradient for `m`.
			function _gradient(f, _, m, _)
				grads = gradient(f, m)
				return grads[1]
			end
			
			

			
			
			
	
		
			# Old-style Flux optimisers: take the implicit-parameter gradient of the
			# zero-argument closure `() -> f(m)` over the parameter collection `ps`.
			function _gradient(f, ::Flux.Optimise.AbstractOptimiser, m, ps::Params)
				lossthunk = () -> f(m)
				return gradient(lossthunk, ps)
			end
			
			

			
			

			
			
			# Old-style Flux optimisers: `update!(optimizer, params, grads)` applies the
			# step (presumably mutating `params` in place), so the same `params` and
			# `model` objects are handed back. This gives the same `(params, model)`
			# return shape as the other `_update!` method.
			function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
				update!(optimizer, params, grads)
				return (params, model)
			end
			
			

			
			
			# Optimisers.jl-style (explicit) update: `Optimisers.update!` returns the new
			# optimiser state and the updated model, which are passed back to the caller.
			function _update!(_, st, model, grads)
				newstate, newmodel = Optimisers.update!(st, model, grads)
				return newstate, newmodel
			end
			
			

			
			

			
			

			
			
			"""
			    step!(learner, phase::ValidationPhase, batch)

			Run one validation step on `batch`, which is destructured into `(xs, ys)`.
			Stores the model output in `state.ŷs` and the loss
			`learner.lossfn(state.ŷs, state.ys)` in `state.loss`. No gradients are
			computed and the parameters are not updated. Returns the step state
			produced by `runstep`.
			"""
			function step!(learner, phase::ValidationPhase, batch)
				xs, ys = batch
				# The step-level event handler (first argument) is not used here.
				function forwardpass(_, state)
					state.ŷs = learner.model(state.xs)
					state.loss = learner.lossfn(state.ŷs, state.ys)
				end
				return runstep(forwardpass, learner, phase, (xs = xs, ys = ys))
			end
			
			

			
			

			
			

			
			
			"""
			    runepoch(epochfn, learner, phase::Phase)

			Run `epochfn(handlefn)` between an `EpochBegin` and an `EpochEnd` event.
			`handlefn(event)` calls `handle(learner.callbacks.runner, event, phase, learner)`.

			If a `CancelEpochException` is thrown, it is logged at debug level and
			`EpochEnd` is still sent. Any other exception is rethrown.
			"""
			function runepoch(epochfn, learner, phase::Phase)
				# `learner.callbacks.runner` is looked up on every call, not once up front.
				dispatch(event) = handle(learner.callbacks.runner, event, phase, learner)
				try
					dispatch(EpochBegin())
					epochfn(dispatch)
					dispatch(EpochEnd())
				catch err
					err isa CancelEpochException || rethrow()
					@debug "Epoch skipped" error = err
					dispatch(EpochEnd())
				end
			end
			
			

			
			

			
			
			"""
			    runstep(stepfn, learner, phase::Phase[, initialstate])

			Run a single step. The pairs of `initialstate` (default: an empty named tuple)
			are wrapped in a `PropDict` `state`, which is stored in `learner.step`. Then
			`stepfn(handlefn, state)` runs between `StepBegin` and `StepEnd` events.
			`handlefn(event)` calls `handle(learner.callbacks.runner, event, phase, learner)`.

			A `CancelStepException` ends the step early: it is logged at debug level and
			`StepEnd` is not sent. Any other exception is rethrown. Returns `state`.
			"""
			function runstep(stepfn, learner, phase::Phase, initialstate = (;))
				state = PropDict(pairs(initialstate))
				dispatch(event) = handle(learner.callbacks.runner, event, phase, learner)
				try
					# Store the state before any event fires; handlers receive `learner`,
					# so they can presumably read it via `learner.step`.
					learner.step = state
					dispatch(StepBegin())
					stepfn(dispatch, state)
					dispatch(StepEnd())
				catch err
					err isa CancelStepException || rethrow()
					@debug "Step skipped" error = err
				end
				return state
			end

Utilities


			
			
			
			"""
			    fit!(learner, nepochs::Int, (trainiter, validiter))

			Run `nepochs` epochs on `learner`. Each epoch is a training epoch over
			`trainiter` followed by a validation epoch over `validiter`.
			Returns `nothing`.
			"""
			function fit!(learner, nepochs::Int, (trainiter, validiter))
				for _ in Base.OneTo(nepochs)
					# One full pass: train first, then evaluate on the validation data.
					epoch!(learner, TrainingPhase(), trainiter)
					epoch!(learner, ValidationPhase(), validiter)
				end
				return nothing
			end
			
			

			
			

			
			
			"""
			    fit!(learner, nepochs::Int)

			Same as `fit!(learner, nepochs, (trainiter, validiter))`, using
			`learner.data.training` and `learner.data.validation` as the iterators.
			"""
			function fit!(learner, nepochs::Int)
				iters = (learner.data.training, learner.data.validation)
				return fit!(learner, nepochs, iters)
			end