# FluxTraining
			
			
			
			
			
			"""
			    Metrics(metrics...) <: Callback

			Callback that tracks metrics during training.

			You can pass any number of `metrics` with every argument being
			- an [`AbstractMetric`](#) like [`Metric`](#); or
			- a function `f(ŷs, ys) -> val`, which is wrapped in a [`Metric`](#).

			This callback is added by default to every [`Learner`](#) unless you pass in
			`usedefaultcallbacks = false`. A metric tracking `learner.lossfn` [`Loss`](#)
			is included by default.

			The computed metrics can be accessed in `learner.cbstate.metricsstep` and
			`learner.cbstate.metricsepoch` (one history per phase) for steps and epochs,
			respectively.

			## Examples

			Track [`accuracy`](#):

			```julia
			cb = Metrics(accuracy)
			```

			Pass in [`Metric`](#)s:

			```julia
			cb = Metrics(
			    Metric(Flux.mse, device = gpu),
			    Metric(Flux.mae, device = gpu)
			)
			```
			"""
			struct Metrics <: Callback
			    metrics::Tuple

			    function Metrics(metrics...)
			        # Plain functions are wrapped in `Metric`; `Loss()` is always tracked first.
			        aswrapped(m) = m isa AbstractMetric ? m : Metric(m)
			        return new(map(aswrapped, (Loss(), metrics...)))
			    end
			end
			
			
			
			
	
			# `Metrics` reads `cbstate.history`, which (presumably) `Recorder` keeps up to
			# date, so it is scheduled after `Recorder`.
			function runafter(::Metrics)
			    return tuple(Recorder)
			end
			
			
			
	
			# Permissions: write both metric histories, read the training history and the
			# current step (for predictions, targets and loss).
			function stateaccess(::Metrics)
			    cbstateaccess = (
			        metricsstep = Write(),
			        metricsepoch = Write(),
			        history = Read(),
			    )
			    return (cbstate = cbstateaccess, step = Read())
			end
			
			
			
			
			
			# Prints the callback together with the string form of every tracked metric.
			function Base.show(io::IO, metrics::Metrics)
			    inner = join(map(string, metrics.metrics), ", ")
			    print(io, "Metrics(", inner, ")")
			end
			# store metrics in cbstate so other callbacks can access them
			
			
			
			function init!(metrics::Metrics, learner)
			    # Names become keys in the shared histories, so they have to be distinct.
			    metricnames = map(metricname, metrics.metrics)
			    allunique(metricnames) || error("Multiple metrics have the same name!")

			    # Only create the per-phase histories if nothing has created them yet.
			    cbstate = learner.cbstate
			    emptyhistories() = DefaultDict{Phase,MVHistory}(() -> MVHistory())
			    if !haskey(cbstate, :metricsstep)
			        cbstate.metricsstep = emptyhistories()
			    end
			    if !haskey(cbstate, :metricsepoch)
			        cbstate.metricsepoch = emptyhistories()
			    end
			end
			
			
			
			
			
			
	
			# Start every epoch with fresh statistics for all tracked metrics.
			function on(::EpochBegin, ::Phase, metrics::Metrics, learner)
			    for metric in metrics.metrics
			        reset!(metric)
			    end
			end
			
			
			
			# After every step, update each metric and log the step value it reports.
			# Metrics whose `stepvalue` is `nothing` (e.g. a `Metric` restricted to other
			# phases) are not logged for this step.
			function on(::StepEnd, phase, metrics::Metrics, learner)
			    # Looked up once: the same per-phase history receives every metric's value.
			    metricsstep = learner.cbstate.metricsstep[phase]
			    step = learner.cbstate.history[phase].steps
			    for metric in metrics.metrics
			        step!(metric, learner, phase)
			        val = stepvalue(metric)
			        if val !== nothing
			            push!(metricsstep, Symbol(metricname(metric)), step, val)
			        end
			    end
			end
			
			
			
			# At the end of an epoch, log each metric's epoch value; `nothing` means the
			# metric was not computed and nothing is logged for it.
			function on(::EpochEnd, phase, metrics::Metrics, learner)
			    epoch = learner.cbstate.history[phase].epochs
			    for metric in metrics.metrics
			        epochval = epochvalue(metric)
			        isnothing(epochval) && continue
			        history = learner.cbstate.metricsepoch[phase]
			        push!(history, Symbol(metricname(metric)), epoch, epochval)
			    end
			end

			# AbstractMetric interface
			
			
			
			
			
			"""
			    abstract type AbstractMetric

			Abstract type for metrics passed to [`Metrics`](#).

			For most use cases, you should use [`Metric`](#), the standard
			implementation.

			## Interface

			If [`Metric`](#) doesn't fit your use case, you can create
			a new subtype of `AbstractMetric` and implement the following
			methods to make it compatible with [`Metrics`](#):

			- [`reset!`](#)`(metric)`
			- [`step!`](#)`(metric, learner)`, or `step!(metric, learner, phase)` if the
			    metric needs to know the current phase
			- [`stepvalue`](#)`(metric)`
			- [`epochvalue`](#)`(metric)`
			- [`metricname`](#)`(metric)`

			`stepvalue` and `epochvalue` may return `nothing`, in which case no value is
			logged.
			"""
			abstract type AbstractMetric <: Any end
			
			
			
			
	
			# Fallback: metrics that ignore the phase only need `step!(metric, learner)`.
			function step!(metric, learner, ::Any)
			    return step!(metric, learner)
			end
			
			
			
			mutable struct Metric{T} <: AbstractMetric
			    metricfn                    # called as `metricfn(ŷs, ys)` every step
			    statistic::OnlineStat{T}    # statistic fitted with the step values
			    _statistic                  # untouched template that `reset!` copies from
			    name
			    device                      # applied to `(ŷs, ys)` before `metricfn`
			    P                           # the metric is computed only when `phase isa P`
			    last::Union{Nothing,T}      # latest step value, `nothing` if not computed
			end
			
			
			
			
			
			
			# Prints a `Metric` using its display name, e.g. `Metric(Accuracy)`.
			function Base.show(io::IO, metric::Metric)
			    print(io, "Metric(", metric.name, ")")
			end
			
			
			
			
			
			"""
			    Metric(metricfn[; statistic, device, name, phase])

			Implementation of [`AbstractMetric`](#) that can be used with the
			[`Metrics`](#) callback.


			## Arguments

			Positional:

			- `metricfn(ŷs, ys)` should return a number.

			Keyword:

			- `statistic` is an `OnlineStats.OnlineStat` that is updated after every step
			    and reset at the start of every epoch by [`Metrics`](#).
			    The default is `OnlineStats.Mean()`.
			- `name` is used for printing and as the key under which the metric's values
			    are stored. The default is the name of `metricfn` with its first letter
			    capitalized.
			- `device` is a function applied to `ŷs` and `ys`
			    before passing them to `metricfn`. The default is `Flux.cpu` so that
			    the callback works if `metricfn` doesn't support arrays from other device
			    types. If, for example, `metricfn` works on `CuArray`s, you can pass
			    `device = Flux.gpu`.
			- `phase = Phase`: a (sub)type of [`Phase`](#) that restricts for which phases the
			    metric is computed.

			## Examples

			- `Metric(accuracy)`
			- `Metric(Flux.mse, device = gpu, name = "Mean Squared Error")`
			- `Metric(Flux.mae, device = gpu)`

			```julia
			metric = Metric(Flux.mse, device = gpu, name = "Mean Squared Error")
			```

			If a metric is expensive to compute and you don't want it to slow down the
			training phase, you can compute it on the validation phase only:

			```julia
			metric = Metric(expensivemetric, phase = ValidationPhase)
			```
			"""
			function Metric(
			        metricfn;
			        name = uppercasefirst(string(metricfn)),
			        statistic = Mean(),
			        device = cpu,
			        phase = Phase,
			    )
			    # The caller's `statistic` is kept untouched as the template that `reset!`
			    # copies from; a private copy is the one that gets fitted.
			    return Metric(metricfn, deepcopy(statistic), statistic, name, device, phase, nothing)
			end
			
			
			
			
			# Prepare the metric for a new epoch.
			#
			# Besides restoring the statistic from its template, the last step value is
			# cleared: `epochvalue` uses `last === nothing` to detect "not computed this
			# epoch", and a stale value from an earlier epoch would otherwise make an epoch
			# without any (matching) steps report the empty statistic's value.
			function reset!(metric::Metric)
			    metric.last = nothing
			    metric.statistic = deepcopy(metric._statistic)
			end
			
			
			
			# Compute the metric for the current step, or mark it as not computed when the
			# phase is not one of the phases selected by `metric.P`.
			function step!(metric::Metric, learner, phase)
			    if !(phase isa metric.P)
			        return metric.last = nothing
			    end
			    ŷs, ys = metric.device((learner.step.ŷs, learner.step.ys))
			    metric.last = metric.metricfn(ŷs, ys)
			    return OnlineStats.fit!(metric.statistic, metric.last)
			end
			
			
			
			
	
			# Value of the most recent step, or `nothing` if it was not computed.
			function stepvalue(metric::Metric)
			    return metric.last
			end
			
			
			# Aggregated value for the epoch; `nothing` when the last step did not compute
			# the metric.
			epochvalue(metric::Metric) =
			    metric.last === nothing ? nothing : OnlineStats.value(metric.statistic)
			
			
			
	
			function metricname(metric::Metric)
			    return metric.name
			end

			# Loss Metric
			
			
			
			mutable struct Loss <: AbstractMetric
			    statistic     # statistic fitted with every step's loss
			    _statistic    # untouched template that `reset!` copies from
			    last          # loss of the most recent step (`nothing` before the first step)
			    name
			end
			
			
			
			
			# Track `learner.step.loss` with a `Mean` using the given OnlineStats weighting.
			function Loss(weight = EqualWeight(); name = "Loss")
			    template = Mean(weight = weight)
			    working = deepcopy(template)
			    return Loss(working, template, nothing, name)
			end
			
			
			
			
			# Discard everything accumulated so far by starting over from the template.
			reset!(loss::Loss) = (loss.statistic = deepcopy(loss._statistic))
			
			
			
			
			# The loss is recorded in every phase, so the phase argument is ignored.
			function step!(metric::Loss, learner, ::Any)
			    metric.last = learner.step.loss
			    return OnlineStats.fit!(metric.statistic, metric.last)
			end
			
			
			
			
			
			
			function Base.show(io::IO, ::Loss)
			    return print(io, "Loss()")
			end
			
			
			
			
			
	
			stepvalue
			(
			
			metric
			::
	
			Loss
			)
			 
			=
			 
			
			metric
			.
			
			last
			
			
			
	
			epochvalue
			(
			
			metric
			::
	
			Loss
			)
			 
			=
			 
			
			
			OnlineStats
			.
			
			value
			(
			
			metric
			.
			
			statistic
			)
			
			
			
	
			function metricname(metric::Loss)
			    return metric.name
			end

			# Utility
			
			
			
			"""
			    SmoothLoss(β = 0.02)

			A [`Loss`](#) metric that tracks an exponentially weighted running mean of the
			loss instead of the plain epoch mean.

			`β` is the weight each new loss value receives in the running mean, so small
			values give a smoother curve that reacts slowly to individual batches.
			"""
			function SmoothLoss(β = 0.02)
			    # OnlineStats' `ExponentialWeight(λ)` weights every new observation (after
			    # the first) with `λ`. Passing `1 - β` here would weight the newest loss
			    # with 0.98 for the default β, i.e. almost no smoothing.
			    # `float` avoids the `ExponentialWeight(lookback::Integer)` constructor when
			    # an integer is passed.
			    return Loss(OnlineStats.ExponentialWeight(float(β)), name = "SmoothLoss")
			end
			
			
			
			@testset "Metric" begin
			    # Accuracy is restricted to validation, so it may only show up for that phase.
			    cb = Metrics(Metric(accuracy, phase = ValidationPhase))
			    learner = testlearner(Recorder(), cb)
			    @test_nowarn fit!(learner, 1)

			    stephistory = learner.cbstate.metricsstep
			    epochhistory = learner.cbstate.metricsepoch
			    @test :Accuracy ∈ keys(stephistory[ValidationPhase()])
			    @test :Accuracy ∉ keys(stephistory[TrainingPhase()])
			    @test :Accuracy ∈ keys(epochhistory[ValidationPhase()])
			    @test :Accuracy ∉ keys(epochhistory[TrainingPhase()])
			end