trace.jl

FluxTraining/callbacks/trace.jl is a source file in module FluxTraining

			
			
			
			
			
			"""
			

			    Traces(preprocess[, phase])

			

			Record a trace during `phase` by apply each pre-processing function in

			`preprocess` to the [`Learner`](#) to produce a trace value.

			The trace is recorded at the end of each learning step.

			

			See [`LogTraces`](#) for logging of the trace values.

			```julia

			cb = Traces((loss2 = learner -> learner.step.loss^2,

			             avg_gnorm = learner -> mean(map((_, g) -> norm(g), pairs(learner.step.grads))))

			            TrainingPhase)

			```

			"""
			

			
			
			struct
			
			 
			

	
			Traces
			{
			
			P
			<:

	
			Phase
			}
			 
			<:
			 

	
			Callback
			
			
    
			
			preprocess
			::
			NamedTuple
			

			end
			

			

			
			function
			 
			

	
			Traces
			(
			
			preprocess
			::
			NamedTuple
			,
			 
			
			
			P
			::
			
			Type
			{
			
			<:

	
			Phase
			}
			 
			=
			 

	
			Phase
			)
			
			
    
			
			return
			 
			
			

	
			Traces
			{
			P
			}
			(
			preprocess
			)
			

			end
			

			

			
			

	
			stateaccess
			(
			
			::

	
			Traces
			)
			 
			=
			 
			
			(
			
			cbstate
			 
			=
			 
			
			(
			
			history
			 
			=
			 
			

	
			Read
			(
			)
			,
			 
			
			tracehistory
			 
			=
			 
			

	
			Write
			(
			)
			)
			,
			
                         
			
			step
			 
			=
			 
			

	
			Read
			(
			)
			)
			

			

			
			function
			 
			

	
			init!
			(
			
			traces
			::

	
			Traces
			,
			 
			learner
			)
			
			
    
			
			
			
			length
			(
			
			traces
			.
			
			preprocess
			)
			 
			==
			 
			
			length
			(
			
			unique
			(
			
			keys
			(
			
			traces
			.
			
			preprocess
			)
			)
			)
			 
			||
			
        
			
			error
			(
			
			"
			Multiple traces have the same name!
			"
			)
			
    
			
			if
			 
			
			!
			
			haskey
			(
			
			learner
			.
			
			cbstate
			,
			 
			
			:
			tracehistory
			)
			
			
        
			
			
			
			learner
			.
			
			cbstate
			.
			
			tracehistory
			 
			=
			 
			
			
			DefaultDict
			{

	
			Phase
			,
			MVHistory
			}
			(
			
			
			(
			)
			 
			->
			 
			
			MVHistory
			(
			)
			)
			
    
			end
			

			end
			

			

			
			function
			 
			
			

	
			on
			(
			
			::

	
			StepEnd
			,
			 
			
			phase
			::
			P
			,
			 
			
			traces
			::
			

	
			Traces
			{
			P
			}
			,
			 
			learner
			)
			 
			where
			 
			{
			
			P
			<:

	
			Phase
			}
			
			
    
			
			step
			 
			=
			 
			
			
			
			
			learner
			.
			
			cbstate
			.
			
			history
			[
			phase
			]
			.
			
			steps
			
    
			
			history
			 
			=
			 
			
			
			
			learner
			.
			
			cbstate
			.
			
			tracehistory
			[
			phase
			]
			
    
			
			for
			
			 
			
			(
			trace_name
			,
			 
			f
			)
			 
			in
			 
			
			pairs
			(
			
			traces
			.
			
			preprocess
			)
			
			
        
			
			val
			 
			=
			 
			
			f
			(
			learner
			)
			
        
			
			push!
			(
			history
			,
			 
			trace_name
			,
			 
			step
			,
			 
			val
			)
			
    
			end
			

			end
			

			

			
			@
			testset
			 
			
			"
			Traces
			"
			 
			
			begin
			
    
			
			cb
			 
			=
			 
			

	
			Traces
			(
			
			(
			
			keya
			 
			=
			 
			
			learner
			 
			->
			 
			
			sum
			(
			
			
			learner
			.
			
			step
			.
			
			ys
			)
			,
			
                 
			
			keyb
			 
			=
			 
			
			learner
			 
			->
			 
			
			sum
			(
			
			
			learner
			.
			
			step
			.
			
			ŷs
			)
			)
			,
			
                

	
			ValidationPhase
			)
			
    
			
			learner
			 
			=
			 
			

	
			testlearner
			(
			

	
			Recorder
			(
			)
			,
			 
			cb
			)
			
    
			
			@
			test_nowarn
			 
			

	
			fit!
			(
			learner
			,
			 
			1
			)
			
    
			
			@
			test
			
			 
			
			:
			keya
			 
			 
			
			keys
			(
			
			
			
			learner
			.
			
			cbstate
			.
			
			tracehistory
			[
			

	
			ValidationPhase
			(
			)
			]
			)
			
    
			
			@
			test
			
			 
			
			:
			keyb
			 
			 
			
			keys
			(
			
			
			
			learner
			.
			
			cbstate
			.
			
			tracehistory
			[
			

	
			ValidationPhase
			(
			)
			]
			)
			
    
			
			@
			test
			
			 
			
			:
			keya
			 
			 
			
			keys
			(
			
			
			
			learner
			.
			
			cbstate
			.
			
			tracehistory
			[
			

	
			TrainingPhase
			(
			)
			]
			)
			
    
			
			@
			test
			
			 
			
			:
			keyb
			 
			 
			
			keys
			(
			
			
			
			learner
			.
			
			cbstate
			.
			
			tracehistory
			[
			

	
			TrainingPhase
			(
			)
			]
			)
			

			end