scheduler.jl

FluxTraining/callbacks/scheduler.jl is a source file in module FluxTraining

			
			
			
			
			
			"""
			

			    Scheduler(schedules...)

			

			Callback for hyperparameter scheduling. Takes pairs of [`HyperParameter`](#)

			types and [ParameterSchedulers.jl schedules](https://github.com/FluxML/ParameterSchedulers.jl).

			

			See [the tutorial](/docs/docs/tutorials/hyperparameters.md) for more information.

			

			## Example

			

			```julia

			es = length(learner.data.training)

			lrschedule = ParameterSchedulers.Step(;λ=1.0, γ=0.9, step_sizes=[10, 20])

			scheduler = Scheduler(

			    LearningRate => lrschedule

			)

			```

			"""
			

			
			mutable
			
			 
			struct
			
			 

	
			Scheduler
			 
			<:
			 

	
			Callback
			
			
    
			
			schedules
			::
			
			Dict
			{
			
			Type
			{
			
			<:

	
			HyperParameter
			}
			,
			 
			
			ParameterSchedulers
			.
			
			AbstractSchedule
			}
			
    
			
			step
			::
			Int
			
    
			
			

	
			Scheduler
			(
			
			args
			...
			
			;
			 
			
			kwargs
			...
			)
			 
			=
			 
			
			new
			(
			
			Dict
			(
			
			args
			...
			
			;
			 
			
			kwargs
			...
			)
			,
			 
			1
			)
			

			end
			

			

			
			
			
			Base
			.
			
			show
			(
			
			io
			::
			IO
			,
			 
			
			scheduler
			::

	
			Scheduler
			)
			 
			=
			 
			
			print
			(
			
    
			io
			,
			 
			
			"
			Scheduler(
			"
			,
			 
			
			join
			(
			
			keys
			(
			
			scheduler
			.
			
			schedules
			)
			,
			 
			
			"
			, 
			"
			)
			,
			 
			
			"
			)
			"
			)
			

			

			

			

			
			function
			 
			

	
			stateaccess
			(
			
			scheduler
			::

	
			Scheduler
			)
			
			
    
			# TODO: implement proper merging of permissions
			
    
			
			if
			
			 
			
			length
			(
			
			keys
			(
			
			scheduler
			.
			
			schedules
			)
			)
			 
			==
			 
			0
			
			
        
			
			hpstateaccess
			 
			=
			 
			
			(
			
			;
			)
			
    
			else
			
			
        
			
			hpstateaccess
			 
			=
			 
			
			merge
			(
			
			

	
			stateaccess
			.
			
			(
			
			keys
			(
			
			scheduler
			.
			
			schedules
			)
			)
			...
			)
			
    
			end
			
    
			
			return
			 
			
			(
			
        
			
			data
			 
			=
			 
			

	
			Read
			(
			)
			,
			 
			
			cbstate
			 
			=
			 
			
			(
			
			;
			 
			
			hyperparams
			=
			

	
			Write
			(
			)
			,
			 
			
			history
			 
			=
			 
			

	
			Read
			(
			)
			)
			,
			
        
			
			hpstateaccess
			...
			
    
			)
			

			end
			

			

			
			

	
			resolveconflict
			(
			
			::

	
			Scheduler
			,
			 
			
			to
			::

	
			ToDevice
			)
			 
			=
			 
			

	
			RunFirst
			(
			to
			)
			

			

			
			function
			 
			

	
			init!
			(
			
			scheduler
			::

	
			Scheduler
			,
			 
			learner
			)
			
			
    
			
			if
			 
			
			!
			
			haskey
			(
			
			learner
			.
			
			cbstate
			,
			 
			
			:
			hyperparams
			)
			
			
        
			
			
			
			learner
			.
			
			cbstate
			.
			
			hyperparams
			 
			=
			 
			
			MVHistory
			(
			)
			
    
			end
			
    
			
			
			scheduler
			.
			
			step
			 
			=
			 
			1
			

			end
			

			

			

			
			function
			 
			

	
			on
			(
			
			::

	
			StepBegin
			,
			 
			
			phase
			::

	
			AbstractTrainingPhase
			,
			 
			
			scheduler
			::

	
			Scheduler
			,
			 
			learner
			)
			
			
    
			
			step
			 
			=
			 
			
			scheduler
			.
			
			step
			
    
			
			for
			
			 
			
			(
			H
			,
			 
			schedule
			)
			 
			in
			 
			
			scheduler
			.
			
			schedules
			
			
        
			
			value
			 
			=
			 
			
			schedule
			(
			step
			)
			
        
			

	
			sethyperparameter!
			(
			learner
			,
			 
			H
			,
			 
			value
			)
			
        
			
			push!
			(
			
            
			
			
			learner
			.
			
			cbstate
			.
			
			hyperparams
			,
			
            
			
			Symbol
			(
			H
			)
			,
			
            
			
			
			
			
			learner
			.
			
			cbstate
			.
			
			history
			[
			phase
			]
			.
			
			steps
			,
			
            
			value
			)
			
    
			end
			
    
			
			
			scheduler
			.
			
			step
			 
			+=
			 
			1
			

			end
			

			

			

			
			
			
			"""
			

			    onecycle(nsteps, max_val, [start_val, end_val; pct_start])

			

			Creates a one-cycle [`Schedule`](#) over `nsteps` steps from `start_val`

			over `max_val` to `end_val`.

			

			## Examples

			

			```julia

			

			epochlength = length(traindataiter)

			cb = Scheduler(LearningRate => onecycle(10epochlength, 0.01))

			learner = Learner(<args>...; callbacks=[cb])

			```

			"""
			

			
			function
			 
			

	
			onecycle
			(
			
    
			nsteps
			,
			 
			max_val
			
			;
			
    
			
			pct_start
			=
			0.25
			,
			
    
			
			div
			=
			25
			,
			 
			
			divfinal
			=
			1e5
			,
			
    
			
			start_val
			=
			
			max_val
			 
			/
			 
			div
			,
			 
			
			end_val
			=
			
			max_val
			 
			/
			 
			divfinal
			)
			
			
    
			
			warmup
			 
			=
			 
			
			ceil
			(
			Int
			,
			 
			
			nsteps
			 
			*
			 
			pct_start
			)
			
    
			
			warmdown
			 
			=
			
			 
			nsteps
			 
			-
			 
			warmup
			

			
    
			
			Sequence
			(
			
			
			Sin
			(
			
			λ0
			=
			max_val
			,
			 
			
			λ1
			=
			start_val
			,
			 
			
			period
			=
			
			2
			*
			warmup
			)
			 
			=>
			 
			warmup
			,
			
             
			
			
			Shifted
			(
			
			Sin
			(
			
			λ0
			=
			max_val
			,
			 
			
			λ1
			=
			end_val
			,
			 
			
			period
			=
			
			2
			*
			warmdown
			)
			,
			 
			
			warmdown
			 
			+
			 
			1
			)
			 
			=>
			 
			warmdown
			)
			

			end