conditional.jl

FluxTraining/callbacks/conditional.jl is a source file in module FluxTraining

			

High-level interface


			
			
			
			
			
			"""
			

			    throttle(callback, Event, freq = 1)

			    throttle(callback, Event, seconds = 1)

			

			Throttle `Event` type for `callback` so that it is triggered either only every

			`freq`'th time  or every `seconds` seconds.

			

			## Examples

			

			If you want to only sporadically log metrics ([`LogMetrics`](#)) or images

			([`LogVisualization`](#)), `throttle` can be used as follows.

			

			Every 10 steps:

			

			```julia

			callback = throttle(LogMetrics(TensorBoardBackend()), StepEnd, freq = 10)

			learner = Learner(<args>; callbacks=[callback])

			```

			

			Or every 5 seconds:

			

			```julia

			callback = throttle(LogMetrics(TensorBoardBackend()), StepEnd, seconds = 5)

			learner = Learner(<args>; callbacks=[callback])

			```

			"""
			

			
			function
			 
			

	
			throttle
			(
			callback
			,
			 
			
			event
			::
			
			Type
			{
			
			<:

	
			Event
			}
			
			;
			 
			
			freq
			 
			=
			 
			nothing
			,
			 
			
			seconds
			 
			=
			 
			nothing
			)
			
			
    
			
			
			xor
			(
			
			isnothing
			(
			freq
			)
			,
			 
			
			isnothing
			(
			seconds
			)
			)
			 
			||
			 
			
			error
			(
			
			"
			Pass either `every` OR `seconds`.
			"
			)
			
    
			
			if
			 
			
			!
			
			isnothing
			(
			freq
			)
			
			
        
			
			return
			 
			

	
			ConditionalCallback
			(
			callback
			,
			 
			

	
			FrequencyThrottle
			(
			freq
			,
			 
			event
			)
			)
			
    
			
			elseif
			 
			
			!
			
			isnothing
			(
			seconds
			)
			
			
        
			
			return
			 
			

	
			ConditionalCallback
			(
			callback
			,
			 
			

	
			TimeThrottle
			(
			seconds
			,
			 
			event
			)
			)
			
    
			end
			

			end

Implementation


			
			
			
			
			
			"""
			

			    abstract type CallbackCondition

			

			Supertype for conditions to use with [`ConditionalCallback`](#).

			To implement a `CallbackCondition`, implement

			[`shouldrun`](#)`(::MyCondition, event, phase)`.

			

			See [`FrequencyThrottle`](#), [`TimeThrottle`](#) and [`throttle`](#).

			"""
			

			
			abstract
			 
			type
			 

	
			CallbackCondition
			 
			end
			

			

			
			
			
			"""
			

			    ConditionalCallback(callback, condition) <: Callback

			

			Wrapper callback that only forwards events to the wrapped callback

			if [`CallbackCondition`](#) `condition` is met. See [`throttle`](#).

			"""
			

			
			
			struct
			
			 

	
			ConditionalCallback
			 
			<:
			 

	
			Callback
			
			
    
			
			callback
			::

	
			Callback
			
    
			
			condition
			::

	
			CallbackCondition
			

			end
			

			

			
			

	
			function stateaccess(cc::ConditionalCallback)
    # The wrapper needs exactly the state access of the callback it wraps.
    return stateaccess(cc.callback)
end
			

			
			

	
			function init!(cc::ConditionalCallback, learner)
    # Initialization is always delegated; the condition only gates events.
    return init!(cc.callback, learner)
end
			

			

			
			function on(event::Event, phase::Phase, cb::ConditionalCallback, learner)
    # Events rejected by the condition are dropped and `nothing` is returned.
    shouldrun(cb.condition, event, phase) || return nothing
    # Otherwise forward to the wrapped callback and pass its result through.
    return on(event, phase, cb.callback, learner)
end
			

			

			

			
			"""
    FrequencyThrottle(freq, event) <: CallbackCondition

[`CallbackCondition`](#) that lets only every `freq`'th event of type `event`
through, beginning with the first one; events of other types always pass.
Usually constructed via [`throttle`](#)`(callback, event; freq)`.
"""
mutable struct FrequencyThrottle <: CallbackCondition
    freq     # number of matching events per forwarded event
    event    # event type being throttled
    counter  # countdown; the matching event seen while it equals 1 is forwarded
    FrequencyThrottle(freq, event) = new(freq, event, 1)
end
			

			

			
			"""
    shouldrun(c::FrequencyThrottle, event, phase)

Return `true` for every `c.freq`'th event that is an instance of `c.event`
(the first such event included) and `false` for the ones in between, updating
`c.counter` accordingly. Events that are not instances of `c.event` always
return `true`. `phase` is not used.
"""
function shouldrun(c::FrequencyThrottle, event, phase)
    # `isa` instead of `typeof(event) == c.event`: an exact type comparison never
    # matches when `c.event` is abstract or a `UnionAll`, which would silently
    # disable throttling. For concrete event types both checks agree.
    event isa c.event || return true
    if c.counter == 1
        # Forward this event and restart the countdown.
        c.counter = c.freq
        return true
    else
        c.counter -= 1
        return false
    end
end
			

			

			
			"""
    TimeThrottle(seconds, event[, timer]) <: CallbackCondition

[`CallbackCondition`](#) that lets events of type `event` through at most about
once every `seconds` seconds; events of other types always pass. `timer` holds
the `time()` at which a matching event was last let through; the default
`nothing` makes the next matching event pass immediately.
Usually constructed via [`throttle`](#)`(callback, event; seconds)`.
"""
mutable struct TimeThrottle <: CallbackCondition
    seconds
    event
    timer
    TimeThrottle(seconds, event, timer = nothing) = new(seconds, event, timer)
end
			

			

			

			
			"""
    shouldrun(c::TimeThrottle, event, phase)

For an event that is an instance of `c.event`, return `true` (and record the
current `time()` in `c.timer`) if no event has been let through yet or at least
`c.seconds` seconds have passed since the last one; otherwise return `false`.
Events that are not instances of `c.event` always return `true`.
`phase` is not used.
"""
function shouldrun(c::TimeThrottle, event, phase)
    # `isa` so that abstract or non-concrete event types can be throttled too;
    # for concrete event types it matches `typeof(event) == c.event`.
    event isa c.event || return true
    current = time()
    # `>=` rather than `>`: with `seconds = 0` every event must pass, but a
    # strict comparison drops events that arrive within the clock's resolution.
    if isnothing(c.timer) || current - c.timer >= c.seconds
        c.timer = current
        return true
    end
    return false
end
			

			

			

			
			@testset "throttle" begin
    # Trains one epoch with callback `cb` and returns the recorded training history.
    function train(cb)
        learner = testlearner(cb)
        epoch!(learner, TrainingPhase())
        return learner.cbstate.history[TrainingPhase()]
    end

    # Unthrottled baseline: the test learner records 16 steps per epoch.
    @test train(Recorder()).steps == 16

    @testset "arguments" begin
        # Exactly one of `freq` and `seconds` must be given.
        @test_throws ErrorException throttle(Recorder(), StepEnd)
        @test_throws ErrorException throttle(Recorder(), StepEnd; freq = 2, seconds = 1)
    end

    @testset "freq" begin
        @test train(throttle(Recorder(), StepEnd, freq = 1)).steps == 16
        @test train(throttle(Recorder(), StepEnd, freq = 2)).steps == 8
        @test train(throttle(Recorder(), StepEnd, freq = 16)).steps == 1
    end

    # NOTE(review): skipped on Windows, presumably due to coarser timer
    # resolution there — confirm before enabling.
    if !Base.Sys.iswindows()
        @testset "seconds" begin
            @test train(throttle(Recorder(), StepEnd, seconds = 0)).steps == 16
            @test train(throttle(Recorder(), StepEnd, seconds = 10)).steps == 1
        end
    end
end