callbacks.jl

YVIq2/src/callbacks/callbacks.jl is a source file in module YVIq2

			

ProgressPrinter


			
			
			
			"""
    ProgressPrinter()

Callback that shows a progress bar over the batches of every epoch. When the
current phase has no data iterator, it prints a one-line `Epoch … …` header
instead and keeps no bar.
"""
mutable struct ProgressPrinter <: Callback
    # Bar for the epoch in progress, or `nothing` while no bar is shown.
    p::Union{Progress,Nothing}
end
			
			

			
			
			
	
		
			# Convenience constructor: start without an active progress bar.
function ProgressPrinter()
    return ProgressPrinter(nothing)
end
			
			

			
			
			
			
			# Compact display: print only the constructor call, never the bar state held in `p`.
function Base.show(io::IO, ::ProgressPrinter)
    print(io, "ProgressPrinter()")
    return nothing
end
			
			

			
			

			
			
			# At the start of each epoch, announce it and, when possible, create a
# progress bar in `cb.p` that `on(::StepEnd, …)` advances once per step.
function on(::EpochBegin,
        phase::Phase,
        cb::ProgressPrinter,
        learner)
    # Number of the epoch being started: epochs recorded so far for this phase, plus one.
    e = learner.cbstate.history[phase].epochs + 1
    # The phase's data iterator, or `nothing` if `learner.data` has no entry for it.
    dataiter = get(learner.data, phasedataiter(phase), nothing)
    # A bar needs the number of steps up front. Iterators that declare an
    # unknown or infinite size (`SizeUnknown` / `IsInfinite`) have no
    # `length`. Calling it would throw and abort training from a purely
    # cosmetic callback, so treat them like a missing iterator.
    # (Types that keep the default `HasLength` must still implement `length`.)
    if isnothing(dataiter) ||
       !(Base.IteratorSize(dataiter) isa Union{Base.HasLength,Base.HasShape})
        cb.p = nothing
        println("Epoch $(e) $(phase) ...")
    else
        cb.p = Progress(length(dataiter), "Epoch $(e) $(phase): ")
    end
end
			
			

			
			

			
			
			
	
		
			# Advance the current epoch's progress bar by one step, if one is active.
# Returns `true` when there is no bar, otherwise whatever `next!` returns.
function on(::StepEnd, ::Phase, cb::ProgressPrinter, learner)
    bar = cb.p
    return bar === nothing ? true : next!(bar)
end
			
			

			
			

			
			
			
	
		
			# Ordering constraint (per the `runafter` name): this callback runs after
# `Recorder`, presumably so `cbstate.history` is up to date when the epoch
# number is read. TODO confirm.
function runafter(::ProgressPrinter)
    return (Recorder,)
end
			
			

			
			
			
	
		
			# Declares the learner state `ProgressPrinter` touches (presumably checked by
# the learner's permission system). It needs read access to `learner.data`
# (to size the bar) and to `learner.cbstate.history` (for the epoch number),
# and nothing else.
#
# The trailing comma in `(history = Read(),)` is required. Without it,
# `(history = Read())` is a parenthesized assignment that evaluates to
# `Read()`, which would grant read access to all of `cbstate` instead of
# just `history`.
stateaccess(::ProgressPrinter) = (data = Read(), cbstate = (history = Read(),))
			
			

			
			

			
			

			
			
			
			
			"""
    MetricsPrinter()

Callback that, at the end of every epoch, prints a one-row table with the
phase, the epoch number and the latest value of each metric found in
`learner.cbstate.metricsepoch[phase]`.
"""
struct MetricsPrinter <: Callback
end
			
			

			
			

			
			
			# At the end of each epoch, print this phase's latest metric values as a table.
function on(::EpochEnd,
        phase::Phase,
        cb::MetricsPrinter,
        learner)
    state = learner.cbstate
    return print_epoch_table(
        state.metricsepoch[phase],    # metric history for this phase
        state.history[phase].epochs,  # shown in the "Epoch" column
        phase)
end
			
			

			
			

			
			

			
			
			"""
    print_epoch_table(mvhistory, epoch, phase)

Print a single-row table via `pretty_table`. The columns are "Phase" (the type
name of `phase`), "Epoch" (`epoch`) and one column per key of `mvhistory`.
Each metric column holds `last(last(mvhistory, key))`, i.e. the final element
of the most recent entry recorded under that key. Numeric cells are formatted
with `PrettyTables.ft_round(5)`.

NOTE(review): `mvhistory` is presumably a ValueHistories `MVHistory` whose
`last(h, key)` returns an `(iteration, value)` pair. Confirm against callers.
"""
function print_epoch_table(mvhistory, epoch, phase)
    metricnames = keys(mvhistory)
    columns = vcat(["Phase", "Epoch"], [string(k) for k in metricnames])
    latest = [last(last(mvhistory, k)) for k in metricnames]
    # Leading cells are a String and an Int, so the row is a Vector{Any}.
    row = vcat([string(typeof(phase)), epoch], latest)
    # Turn the row vector into a 1×N matrix: one table row.
    table = reshape(row, 1, :)
    pretty_table(table; header = columns, formatters = PrettyTables.ft_round(5))
end
			
			

			
			

			
			
			
	
		
			# Declares (presumably for the learner's permission checks) that
# `MetricsPrinter` only reads `cbstate.metricsepoch` and `cbstate.history`.
stateaccess(::MetricsPrinter) =
    (cbstate = (metricsepoch = Read(), history = Read()),)
			
			

			
			
			
	
		
			# Ordering constraint (per the `runafter` name): this callback runs after
# `Metrics`, presumably so `cbstate.metricsepoch` already holds this epoch's
# values when the table is printed. TODO confirm.
function runafter(::MetricsPrinter)
    return (Metrics,)
end

StopOnNaNLoss


			
			
			
			
			
			"""
    StopOnNaNLoss()

Callback that cancels fitting by throwing a `CancelFittingException` when the
loss of a training step is `NaN` after the backward pass.
"""
struct StopOnNaNLoss <: Callback
end
			
			

			
			

			
			
			# After each backward pass of a training phase, abort fitting if the loss is NaN.
# Throws `CancelFittingException("Encountered NaN loss")`; otherwise returns `true`.
function on(::BackwardEnd, ::AbstractTrainingPhase, ::StopOnNaNLoss, learner)
    if isnan(learner.step.loss)
        throw(CancelFittingException("Encountered NaN loss"))
    end
    return true
end
			
			

			
			

			
			
			
	
		
			# Declares (presumably for the learner's permission checks) that
# `StopOnNaNLoss` only reads `learner.step.loss`.
#
# The trailing commas make these one-field NamedTuples. Written as
# `(loss = Read())`, the inner part is a parenthesized assignment that
# evaluates to `Read()`, which would grant read access to the whole of `step`.
stateaccess(::StopOnNaNLoss) = (step = (loss = Read(),),)
			
			

			
			

			
			

			
			

			
			
			
			
			"""
    ToDevice(movedatafn, movemodelfn)

Callback that moves training onto a device:
- at the start of every epoch the model is replaced by `movemodelfn(model)`;
- at the start of every step the batch `step.xs` / `step.ys` are replaced by
  `movedatafn(xs)` / `movedatafn(ys)`.
"""
struct ToDevice <: Callback
    movedatafn   # applied to each batch's `xs` and `ys` at `StepBegin`
    movemodelfn  # applied to the model at `EpochBegin`
end
			
			

			
			

			
			

			
			
			# At the start of each epoch, move the model with `cb.movemodelfn` and
# install the result on the learner via `model!`.
function on(::EpochBegin, ::Phase, cb::ToDevice, learner)
    moved = cb.movemodelfn(learner.model)
    return model!(learner, moved)
end
			
			

			
			

			
			
			
	
		
			# Learner state `ToDevice` may touch (presumably checked by the learner's
# permission system): it replaces the model, overwrites the batch stored in
# `step`, and writes `params`. NOTE(review): `params`/`optimizer` are
# presumably needed by `model!`. TODO confirm.
stateaccess(::ToDevice) =
    (; model = Write(), params = Write(), step = Write(), optimizer = Read())
			
			

			
			

			
			
			
	
		
			"""
    ToGPU()

`ToDevice` callback that uses `gpu` both to move each batch and to move the
model.
"""
function ToGPU()
    return ToDevice(gpu, gpu)
end
			
			

			
			

			
			

			
			
			# At the start of each step, move the batch inputs and targets with
# `cb.movedatafn`, overwriting them in `learner.step`.
function on(::StepBegin, ::Phase, cb::ToDevice, learner)
    step = learner.step
    move = cb.movedatafn
    step.xs = move(step.xs)
    step.ys = move(step.ys)
end
			
			

			
			

			
			

			
			
			"""
    garbagecollect()

Run a full Julia garbage collection. On Linux, also call libc's
`malloc_trim(0)` to ask the allocator to return freed heap memory to the
operating system. Returns `nothing`.

NOTE(review): `malloc_trim` is a glibc extension. On non-glibc Linux (e.g.
musl) the symbol may be missing and this `ccall` would fail. Confirm the
supported platforms.
"""
function garbagecollect()
    GC.gc()
    if Base.Sys.islinux()
        # The C prototype is `int malloc_trim(size_t pad)`. Declaring it exactly
        # passes `pad` as a full-width `size_t`; the previous `Cint` argument
        # and `Cvoid` return did not match that signature. The return value
        # (1 if memory was released, 0 otherwise) is intentionally ignored.
        ccall(:malloc_trim, Cint, (Csize_t,), 0)
    end
    return nothing
end
			
			

			
			

			
			

			
			
			"""
    GarbageCollect(nsteps = 100)

Callback that calls `garbagecollect()` at `StepEnd` of any `Phase`. It is wrapped
in `throttle(…, StepEnd; freq = nsteps)`, presumably so it fires only once every
`nsteps` steps rather than after every step.
"""
GarbageCollect(nsteps::Int = 100) =
    throttle(CustomCallback(_ -> garbagecollect(), StepEnd, Phase), StepEnd; freq = nsteps)