callbacks.jl

FluxTraining/callbacks/callbacks.jl is a source file in module FluxTraining

			

ProgressPrinter


			
			
			
			
			
			"""
    ProgressPrinter()

Shows a progress bar for the epoch that is currently running.
"""
mutable struct ProgressPrinter <: Callback
    # Progress bar of the current epoch, or `nothing` while no bar is active.
    p::Union{Progress,Nothing}
end
			

			
			

	
			# Zero-argument constructor: begin without an active progress bar.
function ProgressPrinter()
    return ProgressPrinter(nothing)
end
			

			
			
			
			# Compact display: render the callback as its zero-argument constructor call.
function Base.show(io::IO, ::ProgressPrinter)
    print(io, "ProgressPrinter()")
    return nothing
end
			

			

			
			# At the start of every epoch, work out the epoch number and set up a progress
# bar sized to this phase's data iterator.
#
# - The epoch number is `learner.cbstate.history[phase].epochs + 1`
#   (presumably the number of epochs already completed in this phase, plus one).
# - The data iterator is looked up in `learner.data` under `phasedataiter(phase)`.
# - A bar needs the iterator's length up front. Sometimes no length is available:
#   there may be no iterator for this phase, or its `Base.IteratorSize` trait may
#   not be `HasLength`/`HasShape` (e.g. `SizeUnknown`, `IsInfinite`), so `length`
#   cannot be relied on. In those cases `cb.p` is cleared, which makes the
#   `StepEnd` handler skip progress updates, and a plain header line is printed
#   instead.
function on(
        ::EpochBegin,
        phase::Phase,
        cb::ProgressPrinter,
        learner)
    e = learner.cbstate.history[phase].epochs + 1
    dataiter = get(learner.data, phasedataiter(phase), nothing)
    # Only trust `length` when the iterator declares a known size.
    haslength = !isnothing(dataiter) &&
        Base.IteratorSize(dataiter) isa Union{Base.HasLength,Base.HasShape}
    if haslength
        cb.p = Progress(length(dataiter), "Epoch $(e) $(phase): ")
    else
        cb.p = nothing
        println("Epoch $(e) $(phase) ...")
    end
end
			

			

			
			

	
			# After every step, advance the progress bar by one tick. `cb.p` is `nothing`
# when no bar is active, and then nothing is updated.
function on(::StepEnd, ::Phase, cb::ProgressPrinter, learner)
    bar = cb.p
    return bar === nothing || next!(bar)
end
			

			

			
			

	
			# Schedule this callback after `Recorder`. The epoch number shown in the bar
# is read from `learner.cbstate.history`.
runafter(::ProgressPrinter) = tuple(Recorder)
			

			
			

	
			# Access permissions of `ProgressPrinter`:
# - `data`: read, to look up the phase's data iterator and size the bar.
# - `cbstate.history`: read, to compute the epoch number.
# The nested permissions must be a NamedTuple: `(history = Read(),)` with a
# trailing comma. Without the comma, `(history = Read())` is just a parenthesized
# assignment that evaluates to `Read()`, which would grant read access to the
# whole of `cbstate` instead of only `history`.
stateaccess(::ProgressPrinter) = (data = Read(), cbstate = (history = Read(),))
			

			

			

			
			
			
			"""
    MetricsPrinter() <: Callback

Prints a table of metric values at the end of every epoch. The values come from
the metrics computed by [`Metrics`](#), so this callback errors when no
`Metrics` callback is in use.

Every [`Learner`](#) gets this callback by default, unless
`usedefaultcallbacks = false` is passed.
"""
struct MetricsPrinter <: Callback end
			

			

			
			# At the end of each epoch, print the epoch's metric values for `phase`. The
# values are taken from `cbstate.metricsepoch[phase]` and the epoch count from
# `cbstate.history[phase].epochs`.
function on(::EpochEnd, phase::Phase, cb::MetricsPrinter, learner)
    state = learner.cbstate
    print_epoch_table(state.metricsepoch[phase], state.history[phase].epochs, phase)
end
			

			

			

			
			# Print a one-row table with columns "Phase", "Epoch" and one column per key
# of `mvhistory`. Each metric cell holds the last element of
# `last(mvhistory, key)`, and numbers are formatted with `PrettyTables.ft_round(5)`.
# NOTE(review): assumes `last(mvhistory, key)` returns an `(iteration, value)`
# pair, as for ValueHistories' `MVHistory` — confirm.
function print_epoch_table(mvhistory, epoch, phase)
    metricnames = keys(mvhistory)
    header = vcat(["Phase", "Epoch"], string.(metricnames))
    latest = [last(last(mvhistory, name)) for name in metricnames]
    row = vcat([string(phase), epoch], latest)
    # `pretty_table` expects a matrix; lay the row out as 1×n.
    table = reshape(row, 1, :)
    pretty_table(table; header = header, formatters = PrettyTables.ft_round(5))
end
			

			

			
			

	
			# Access permissions of `MetricsPrinter`: read-only access to the per-epoch
# metric values and to the history in the callback state.
stateaccess(::MetricsPrinter) = (cbstate = (metricsepoch = Read(), history = Read()),)
			

			
			

	
			# Schedule this callback after `Metrics`, which provides the values it prints.
runafter(::MetricsPrinter) = tuple(Metrics)

StopOnNaNLoss


			
			
			
			
			
			"""
    StopOnNaNLoss()

Cancels training as soon as the loss becomes NaN.

Every [`Learner`](#) gets this callback by default, unless
`usedefaultcallbacks = false` is passed.
"""
struct StopOnNaNLoss <: Callback end
			

			

			
			# After the backward pass of a training phase, cancel fitting by throwing a
# `CancelFittingException` when `learner.step.loss` is NaN. Returns `true`
# otherwise.
function on(::BackwardEnd, ::AbstractTrainingPhase, ::StopOnNaNLoss, learner)
    if isnan(learner.step.loss)
        throw(CancelFittingException("Encountered NaN loss"))
    end
    return true
end
			

			

			
			

	
			# Access permissions of `StopOnNaNLoss`: read-only access to `step.loss`, the
# only piece of state its `BackwardEnd` handler inspects.
# The nested permissions must be a NamedTuple: `(loss = Read(),)` with a trailing
# comma. Without the comma, `(loss = Read())` is a parenthesized assignment that
# evaluates to `Read()`, which would grant read access to the whole `step`.
stateaccess(::StopOnNaNLoss) = (step = (loss = Read(),),)
			

			

			

			

			
			
			
			"""
    ToDevice(movedatafn, movemodelfn) <: Callback

Transfers step data and the model to a device. `movedatafn` is applied to step
data and `movemodelfn` to the model. For instance, `ToDevice(Flux.gpu, Flux.gpu)`
moves both to a GPU when one is available. See [`ToGPU`](#).

Out of the box, only `step.xs` and `step.ys` are moved. To move other state,
implement `on(::StepBegin, ::MyCustomPhase, ::ToDevice, learner)`.
"""
struct ToDevice <: Callback
    movedatafn   # applied to step data
    movemodelfn  # applied to the model
end
			

			

			

			
			# At the start of every epoch, move the model with `cb.movemodelfn` and install
# the moved model on the learner via `model!`.
function on(::EpochBegin, ::Phase, cb::ToDevice, learner)
    moved = cb.movemodelfn(learner.model)
    return model!(learner, moved)
end
			

			

			
			

	
			# Access permissions of `ToDevice`:
# - `model`, `params`: write, because the model is replaced via `model!`
#   (NOTE(review): presumably `model!` also refreshes `params` — confirm).
# - `step`: write, because `xs`/`ys` are overwritten with moved copies.
# - `optimizer`: read (NOTE(review): presumably needed by `model!` — confirm).
stateaccess(::ToDevice) =
    (model = Write(), params = Write(), step = Write(), optimizer = Read())
			

			

			
			
			
			"""
    ToGPU()

Callback that moves model and batch data to the GPU during training.
Convenience for [`ToDevice`](#)`(Flux.gpu, Flux.gpu)`, i.e. `gpu` is used both
to move the step data and to move the model.
"""
ToGPU() = ToDevice(gpu, gpu)
			

			

			

			
			# Before every step, pass `learner.step.xs` and `learner.step.ys` through
# `cb.movedatafn` and store the results back into the step state.
function on(::StepBegin, ::Phase, cb::ToDevice, learner)
    state = learner.step
    state.xs = cb.movedatafn(state.xs)
    state.ys = cb.movedatafn(state.ys)
end
			

			

			

			
			# Run a full garbage collection. On Linux, additionally call
# `malloc_trim(0)` (a glibc extension), which can release free heap memory back
# to the operating system. Returns `nothing`.
# NOTE(review): `malloc_trim` may be missing on non-glibc Linux (e.g. musl) — confirm.
function garbagecollect()
    GC.gc()
    Sys.islinux() && ccall(:malloc_trim, Cvoid, (Cint,), 0)
    return nothing
end
			

			

			

			
			
			
			"""
    GarbageCollect(nsteps = 100)

Forces a garbage collection once every `nsteps` steps (100 by default). Useful
when you see memory leaks, for example from parallel data loading.

On Linux systems, an extra C call is made that can sometimes help.
"""
function GarbageCollect(nsteps::Int = 100)
    # Fires on every `StepEnd` of any `Phase`; `throttle` limits it to once per
    # `nsteps` steps.
    collector = CustomCallback(_ -> garbagecollect(), StepEnd, Phase)
    return throttle(collector, StepEnd; freq = nsteps)
end