# training.jl
#
# FluxTraining/training.jl is a source file in module FluxTraining.

			
			
			
			
			
			"""
    epoch!(learner, phase[, dataiter])

Run one epoch of `phase` for `learner`: every batch/item produced by
`dataiter` is handed to [`step!`](#) in turn.

When `dataiter` is omitted it defaults to
`learner.data[phasedataiter(phase)]`.

## Extending

The default method walks over `dataiter` and calls [`step!`](#) on each
batch. To change that for your own phase, implement
`epoch!(learner, ::MyPhase, dataiter)`.

Custom `epoch!` methods are encouraged to wrap their work in
[`runepoch`](#), which dispatches the epoch begin and end events and deals
with [`CancelEpochException`](#)s.

The default implementation serves as a reference.
"""
function epoch!(learner, phase::Phase, dataiter=learner.data[phasedataiter(phase)])
    # `runepoch` supplies an event handle; the default epoch has no use for it.
    return runepoch(learner, phase) do _
        foreach(dataiter) do batch
            step!(learner, phase, batch)
        end
    end
end
			

			

			

			
			
			
			"""
    step!(learner, phase::Phase, batch)

Perform a single step for `learner` on `batch`. What a step does is
determined by `phase`.

## Extending

Every custom [`Phase`](#) has to provide a method for this function.
When writing one, it is recommended to wrap the work in [`runstep`](#),
which dispatches the step begin and end events and takes care of
[`CancelStepException`](#)s.

The methods for [`TrainingPhase`](#) and [`ValidationPhase`](#) can be used
as a reference.
"""
function step! end
			

			

			

			
			# Training step. `batch` is destructured into inputs and targets, which seed
# the step state as `xs` and `ys`. Inside `runstep` the loss is differentiated
# via `_gradient` and the result is applied with `_update!`.
#
# Events sent through `handle`, in order: `LossBegin` (after the forward pass),
# `BackwardBegin` (after the loss is computed) and `BackwardEnd` (after
# `_gradient` returns). `ŷs`, `loss` and `grads` are written to `state`.
function step!(learner, phase::TrainingPhase, batch)
    inputs, targets = batch
    return runstep(learner, phase, (xs = inputs, ys = targets)) do handle, state
        # Forward pass plus loss; this is the function handed to `_gradient`.
        function lossof(model)
            state.ŷs = model(state.xs)
            handle(LossBegin())
            state.loss = learner.lossfn(state.ŷs, state.ys)
            handle(BackwardBegin())
            return state.loss
        end

        state.grads = _gradient(lossof, learner.optimizer, learner.model, learner.params)
        handle(BackwardEnd())

        # `_update!` returns the (possibly new) params and model; store both back.
        learner.params, learner.model = _update!(
            learner.optimizer, learner.params, learner.model, state.grads
        )
    end
end

# Handle both the old Flux.jl optimisers and the new Optimisers.jl optimisers.


			
			
			
			
			# Generic fallback: differentiate `f` with respect to the model `m` and return
# the first entry of what `gradient` gives back. The optimiser (2nd argument)
# and the 4th argument are ignored here.
function _gradient(f, _, m, _)
    grads = gradient(f, m)
    return grads[1]
end
			

			
			
			# Method for `Flux.Optimise.AbstractOptimiser` together with `Params`:
# differentiate a zero-argument closure that evaluates `f(m)`, with respect
# to `ps`. The result of `gradient` is returned as is.
function _gradient(f, ::Flux.Optimise.AbstractOptimiser, m, ps::Params)
    return gradient(ps) do
        f(m)
    end
end
			

			

			
			# Update step for a `Flux.Optimise.AbstractOptimiser`: call `update!` with the
# optimiser, `params` and `grads`, then hand back the same `params` and
# `model` objects that were passed in.
function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
    update!(optimizer, params, grads)
    return (params, model)
end
			

			
			# Generic fallback: `Optimisers.update!` is called with `st`, `model` and
# `grads`; its result is destructured into two values, which are returned
# as a tuple. The first argument (the optimiser) is ignored.
function _update!(_, st, model, grads)
    newstate, newmodel = Optimisers.update!(st, model, grads)
    return (newstate, newmodel)
end
			

			

			

			
			# Validation step. `batch` is destructured into inputs and targets, which seed
# the step state as `xs` and `ys`. Inside `runstep` the model output is stored
# as `state.ŷs` and the loss as `state.loss`. No gradients are computed and
# the event handle is not used.
function step!(learner, phase::ValidationPhase, batch)
    inputs, targets = batch
    initial = (xs = inputs, ys = targets)
    return runstep(learner, phase, initial) do _, state
        state.ŷs = learner.model(state.xs)
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end
			

			

			

			
			
			
			"""
    runepoch(epochfn, learner, phase)

Execute `epochfn` within the context of an epoch. `epochfn` is called as
`epochfn(handle)`, where `handle(e)` dispatches the event `e` to the
callbacks of `learner`.

[`EpochBegin`](#) is dispatched before `epochfn` runs and [`EpochEnd`](#)
afterwards. If a [`CancelEpochException`](#) is thrown, it is logged at
debug level and [`EpochEnd`](#) is still dispatched; any other exception is
rethrown.
"""
function runepoch(epochfn, learner, phase::Phase)
    dispatch(event) = handle(learner.callbacks.runner, event, phase, learner)
    try
        dispatch(EpochBegin())
        epochfn(dispatch)
        dispatch(EpochEnd())
    catch err
        # Only a cancelled epoch is recoverable; everything else propagates.
        err isa CancelEpochException || rethrow()
        @debug "Epoch skipped" error = err
        dispatch(EpochEnd())
    end
end
			

			

			
			
			
			"""
    runstep(stepfn, learner, phase[, initialstate]) -> state

Execute `stepfn` within the context of a step. `stepfn` is called as
`stepfn(handle, state)`, where `handle(e)` dispatches the event `e` and
`state` is a [`PropDict`](#) that step data, gradients and losses can be
written to. `state` is returned.

`state` is built from the pairs of `initialstate` (an empty named tuple by
default) and is assigned to `learner.step` before any event is dispatched.

[`StepBegin`](#) is dispatched before `stepfn` runs and [`StepEnd`](#)
afterwards. If a [`CancelStepException`](#) is thrown, it is logged at debug
level and `state` is returned as it stands; any other exception is rethrown.
"""
function runstep(stepfn, learner, phase::Phase, initialstate = (;))
    state = PropDict(pairs(initialstate))
    dispatch(event) = handle(learner.callbacks.runner, event, phase, learner)
    try
        learner.step = state
        dispatch(StepBegin())
        stepfn(dispatch, state)
        dispatch(StepEnd())
    catch err
        # Only a cancelled step is recoverable; everything else propagates.
        err isa CancelStepException || rethrow()
        @debug "Step skipped" error = err
    end
    return state
end

# Utilities


			
			
			
			
			
			"""
    fit!(learner, nepochs)
    fit!(learner, nepochs, (trainiter, validiter))

Train `learner` for `nepochs` rounds, each consisting of one training epoch
followed by one validation epoch. The data iterators passed in are used; if
none are given, `learner.data.training` and `learner.data.validation` are
used instead.

## Examples

```julia
fit!(learner, 10)
fit!(learner, 10, (traindl, valdl))
```
"""
function fit!(learner, nepochs::Int, (trainiter, validiter))
    for _ in 1:nepochs
        epoch!(learner, TrainingPhase(), trainiter)
        epoch!(learner, ValidationPhase(), validiter)
    end
    return nothing
end
			

			

			
			# Convenience method: take the training and validation iterators from
# `learner.data` and forward to the three-argument `fit!`.
function fit!(learner, nepochs::Int)
    trainiter = learner.data.training
    validiter = learner.data.validation
    return fit!(learner, nepochs, (trainiter, validiter))
end