metrics.jl

`FluxTraining/callbacks/metrics.jl` is a source file in module `FluxTraining`.

			
			
			
			
			
			"""
			

			    Metrics(metrics...) <: Callback

			

			Callback that tracks metrics during training.

			

			You can pass any number of `metrics` with every argument being

			- an [`AbstractMetric`](#) like [`Metric`](#); or

			- a function `f(ŷs, ys) -> val`

			

			This callback is added by default to every [`Learner`](#) unless you pass in

			`usedefaultcallbacks = false`. A metric tracking `learner.lossfn` [`Loss`](#)

			is included by default.

			

			The computed metrics can be access in `learner.cbstate.metricsstep` and

			`learner.cbstate.metricsepoch` for steps and epochs, respectively.

			

			## Examples

			

			Track [`accuracy`](#):

			

			```julia

			cb = Metrics(accuracy)

			```

			

			Pass in [`Metric`]s:

			

			```julia

			cb = Metrics(

			    Metric(Flux.mse, device = gpu),

			    Metric(Flux.mae, device = gpu)

			)

			```

			

			"""
			

			
			
			struct
			
			 

	
			Metrics
			 
			<:
			 

	
			Callback
			
			
    
			
			metrics
			::
			Tuple
			
    
			
			function
			 
			

	
			Metrics
			(
			
			metrics
			...
			)
			
			
        
			
			return
			 
			
			new
			(
			
            
			
			Tuple
			(
			
			
			
			m
			 
			isa
			 

	
			AbstractMetric
			 
			?
			 
			m
			 
			:
			 
			

	
			Metric
			(
			m
			)
			 
			for
			
			 
			m
			 
			in
			 
			
			(
			

	
			Loss
			(
			)
			,
			 
			
			metrics
			...
			)
			)
			,
			
        
			)
			
    
			end
			

			end
			

			

			
			

	
			# `Metrics` must run after `Recorder`: its step/epoch handlers read the
# counters in `learner.cbstate.history` (presumably maintained by `Recorder`).
function runafter(::Metrics)
    return (Recorder,)
end
			

			
			

	
			# Declare the learner state `Metrics` touches: it writes the step/epoch metric
# histories, and only reads the recorded history and the current step.
function stateaccess(::Metrics)
    cbstateaccess = (metricsstep = Write(), metricsepoch = Write(), history = Read())
    return (cbstate = cbstateaccess, step = Read())
end
			

			

			
			
			
			# Print as `Metrics(<metric>, <metric>, ...)` using each metric's own `string` form.
function Base.show(io::IO, metrics::Metrics)
    parts = map(string, metrics.metrics)
    print(io, "Metrics(", join(parts, ", "), ")")
end

Store the metric histories in `cbstate` so that other callbacks can access them.


			
			
			
			# Validate that all metric names are distinct, then create the per-phase step and
# epoch histories in `learner.cbstate` unless an earlier run already created them.
function init!(metrics::Metrics, learner)
    metricnames = metricname.(metrics.metrics)
    allunique(metricnames) || error("Multiple metrics have the same name!")

    # One `MVHistory` per phase, created on demand by the `DefaultDict`.
    emptyhistories() = DefaultDict{Phase,MVHistory}(() -> MVHistory())
    if !haskey(learner.cbstate, :metricsstep)
        learner.cbstate.metricsstep = emptyhistories()
    end
    if !haskey(learner.cbstate, :metricsepoch)
        learner.cbstate.metricsepoch = emptyhistories()
    end
end
			

			

			

			

			
			

	
			# At the start of every epoch, call `reset!` on each tracked metric.
function on(::EpochBegin, ::Phase, metrics::Metrics, learner)
    for metric in metrics.metrics
        reset!(metric)
    end
    return nothing
end
			

			

			
			# After every step, update each metric and record its step value (if any) in
# `learner.cbstate.metricsstep[phase]`, keyed by the metric's name and the step count.
function on(::StepEnd, phase, metrics::Metrics, learner)
    # Fetch this phase's step history once; it was previously re-fetched (and the
    # local reassigned) for every metric inside the loop.
    metricsstep = learner.cbstate.metricsstep[phase]
    step = learner.cbstate.history[phase].steps
    for metric in metrics.metrics
        step!(metric, learner, phase)
        val = stepvalue(metric)
        # `nothing` means the metric produced no value for this step.
        if val !== nothing
            push!(metricsstep, Symbol(metricname(metric)), step, val)
        end
    end
end
			

			

			
			# At the end of every epoch, record each metric's epoch value in
# `learner.cbstate.metricsepoch[phase]`, keyed by metric name and epoch count.
# Metrics whose `epochvalue` is `nothing` are skipped.
function on(::EpochEnd, phase, metrics::Metrics, learner)
    epoch = learner.cbstate.history[phase].epochs
    for metric in metrics.metrics
        epochval = epochvalue(metric)
        epochval === nothing && continue
        push!(learner.cbstate.metricsepoch[phase], Symbol(metricname(metric)), epoch, epochval)
    end
end

AbstractMetric interface


			
			
			
			
			
			"""
			

			    abstract type AbstractMetric

			

			Abstract type for metrics passed to [`Metrics`](#).

			

			For most use cases, you should use [`Metric`](#), the standard

			implementation.

			

			## Interface

			

			If [`Metric`](#) doesn't fit your use case, you can create

			a new subtype of `AbstractMetric` and implement the following

			methods to make it compatible with [`Metrics`](#):

			

			- [`reset!`](#)`(metric)`

			- [`step!`](#)`(metric, learner)`

			- [`stepvalue`](#)`(metric)`

			- [`epochvalue`](#)`(metric)`

			- [`metricname`](#)`(metric)`

			"""
			

			
			abstract
			 
			type
			 

	
			AbstractMetric
			 
			end
			

			

			
			

	
			# Fallback for metrics that only implement `step!(metric, learner)`: the phase
# argument supplied by `Metrics` is dropped.
function step!(metric, learner, _)
    return step!(metric, learner)
end
			

			

			
			# State of a standard metric; construct it through the keyword constructor `Metric(metricfn; ...)`.
mutable struct Metric{T} <: AbstractMetric
    metricfn::Any                # called as `metricfn(ŷs, ys)` in `step!`
    statistic::OnlineStat{T}     # running statistic, `fit!`ted with every computed step value
    _statistic::Any              # untouched template; `reset!` restores `statistic` from a deepcopy of it
    name::Any                    # display name; also the key under which values are recorded
    device::Any                  # applied to `(ŷs, ys)` before calling `metricfn`
    P::Any                       # phase type; the metric is only computed when `phase isa P`
    last::Union{Nothing,T}       # value from the latest step, or `nothing` if that step was skipped
end
			

			

			
			
			
			
			# Print a metric as `Metric(<name>)`.
function Base.show(io::IO, metric::Metric)
    print(io, "Metric(")
    print(io, metric.name)
    print(io, ")")
end
			

			

			
			
			
			"""
			

			    Metric(metricfn[; statistic, device, name])

			

			Implementation of [`AbstractMetric`](#) that can be used with the

			[`Metrics`](#) callback.

			

			

			## Arguments

			

			Positional:

			

			- `metricfn(ŷs, ys)` should return a number.

			

			Keyword:

			

			- `statistic` is a `OnlineStats.Statistic` that is updated after every step.

			    The default is `OnlineStats.Mean()`

			- `name` is used for printing.

			- `device` is a function applied to `ŷs` and `ys`

			    before passing them to `metricfn`. The default is `Flux.cpu` so that

			    the callback works if `metricfn` doesn't support arrays from other device

			    types. If, for example, `metricfn` works on `CurArray`s, you can pass

			    `device = Flux.gpu`.

			- `phase = Phase`: a (sub)type of [`Phase`](#) that restricts for which phases the

			    metric is computed.

			

			## Examples

			

			- `Metric(accuracy)`

			- `Metric(Flux.mse, device = gpu, name = "Mean Squared Error")`

			- `Metric(Flux.mae, device = gpu)`

			

			```julia

			cb = Metric(Flux.mse, device = gpu, name = "Mean Squared Error")

			```

			

			If a metric is expensive to compute and you don't want it to slow down the

			training phase, you can compute it on the validation phase only:

			

			```julia

			cb = Metric(expensivemetric, P = ValidationPhase)

			```

			"""
			

			
			function
			 
			

	
			Metric
			(
			
    
			metricfn
			
			;
			
    
			
			name
			 
			=
			 
			
			uppercasefirst
			(
			
			string
			(
			metricfn
			)
			)
			,
			
    
			
			statistic
			 
			=
			 
			
			Mean
			(
			)
			,
			
    
			
			device
			 
			=
			 

	
			cpu
			,
			
    
			
			phase
			 
			=
			 

	
			Phase
			,
			

			)
			
			

			
    
			
			return
			 
			

	
			Metric
			(
			metricfn
			,
			 
			
			deepcopy
			(
			statistic
			)
			,
			 
			statistic
			,
			 
			name
			,
			 
			device
			,
			 
			phase
			,
			 
			nothing
			)
			

			end
			

			

			

			
			# Discard the accumulated statistic, replacing it with a fresh copy of the template.
function reset!(metric::Metric)
    fresh = deepcopy(metric._statistic)
    metric.statistic = fresh
    return fresh
end
			

			

			
			# Compute the metric for the current step if `phase` matches the metric's phase
# type and fold it into the running statistic; otherwise mark the step as skipped.
function step!(metric::Metric, learner, phase)
    if !(phase isa metric.P)
        metric.last = nothing
        return nothing
    end
    ŷs, ys = metric.device((learner.step.ŷs, learner.step.ys))
    metric.last = metric.metricfn(ŷs, ys)
    OnlineStats.fit!(metric.statistic, metric.last)
end
			

			

			
			

	
			# Value from the most recent `step!`, or `nothing` if that step was skipped.
function stepvalue(metric::Metric)
    return metric.last
end
			

			
			# Current value of the running statistic, or `nothing` if the latest step was skipped.
epochvalue(metric::Metric) = isnothing(metric.last) ? nothing : OnlineStats.value(metric.statistic)
			

			
			

	
			# Name under which this metric's values are recorded.
function metricname(metric::Metric)
    return metric.name
end

Loss Metric


			
			
			
			# Built-in metric that records `learner.step.loss` after every step, regardless of phase.
mutable struct Loss <: AbstractMetric
    statistic::Any   # running statistic, `fit!`ted with each step's loss
    _statistic::Any  # untouched template; `reset!` restores `statistic` from a deepcopy of it
    last::Any        # loss of the latest step (`nothing` until the first step)
    name::Any        # key under which values are recorded
end
			

			

			

			
			# Create a loss metric averaging the step losses with the given OnlineStats weight.
function Loss(weight = EqualWeight(); name = "Loss")
    template = Mean(weight = weight)
    working = deepcopy(template)
    return Loss(working, template, nothing, name)
end
			

			

			

			
			# Discard the accumulated statistic, replacing it with a fresh copy of the template.
function reset!(loss::Loss)
    fresh = deepcopy(loss._statistic)
    loss.statistic = fresh
    return fresh
end
			

			

			

			
			# Record the current step's loss and fold it into the running statistic (every phase).
function step!(metric::Loss, learner, _)
    steploss = learner.step.loss
    metric.last = steploss
    OnlineStats.fit!(metric.statistic, steploss)
end
			

			

			

			
			
			
			# Loss metrics always print as `Loss()`, whatever their `name`.
function Base.show(io::IO, ::Loss)
    print(io, "Loss()")
    return nothing
end
			

			

			

			
			

	
			# Loss of the most recent step.
function stepvalue(metric::Loss)
    return metric.last
end
			

			
			

	
			# Current value of the running loss statistic.
function epochvalue(metric::Loss)
    return OnlineStats.value(metric.statistic)
end
			

			
			

	
			# Name under which the loss values are recorded.
function metricname(metric::Loss)
    return metric.name
end

Utility


			
			
			
			"""
    SmoothLoss(β = 0.02)

[`Loss`](#) metric named `"SmoothLoss"` whose value is an exponentially weighted
moving average of the step losses. `β` is the weight given to each new step's
loss, so smaller values produce a smoother curve.
"""
function SmoothLoss(β = 0.02)
    # `OnlineStats.ExponentialWeight(λ)` weights the newest observation by `λ`, so the
    # smoothing factor is passed directly. The previous `1 - β` put weight 0.98 on
    # every new loss, which barely smoothed at all. `float` keeps an integer `β`
    # from hitting the `ExponentialWeight(lookback::Integer)` constructor.
    return Loss(OnlineStats.ExponentialWeight(float(β)), name = "SmoothLoss")
end
			

			

			
			@testset "Metric" begin
    # A metric restricted to the validation phase must only appear in the
    # validation histories, for both step and epoch values.
    cb = Metrics(Metric(accuracy, phase = ValidationPhase))
    learner = testlearner(Recorder(), cb)
    @test_nowarn fit!(learner, 1)
    for store in (learner.cbstate.metricsstep, learner.cbstate.metricsepoch)
        @test :Accuracy ∈ keys(store[ValidationPhase()])
        @test :Accuracy ∉ keys(store[TrainingPhase()])
    end
end