Metric

struct defined in module FluxTraining


			Metric(metricfn[; statistic, device, name, phase])

Implementation of AbstractMetric that can be used with the
Metrics callback.

Arguments

Positional:

  • metricfn(ŷs, ys) should return a number.

Keyword:

  • statistic is an OnlineStats.OnlineStat that is updated after every step.
    The default is OnlineStats.Mean().

  • name is used for printing.

  • device is a function applied to ŷs and ys
    before passing them to metricfn. The default is Flux.cpu, so that
    the callback works even if metricfn doesn’t support arrays from other device
    types. If, for example, metricfn works on CuArrays, you can pass
    device = Flux.gpu.

  • phase = Phase: a (sub)type of Phase that restricts the phases in which the
    metric is computed. The default, Phase, covers every phase.
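To make the contract between metricfn, device and statistic concrete, here is a minimal standalone sketch in plain Julia. It does not use FluxTraining: mae, RunningMean, update! and metric_step! are hypothetical names, RunningMean stands in for OnlineStats.Mean(), and identity stands in for Flux.cpu. It assumes that on each step the callback moves ŷs and ys with device, calls metricfn, and folds the resulting number into the statistic.

```julia
# Hypothetical stand-ins for illustration only; this is not FluxTraining's implementation.

# A metric function following the documented contract: metricfn(ŷs, ys) returns a number.
mae(ŷs, ys) = sum(abs.(ŷs .- ys)) / length(ys)

# Minimal running mean standing in for OnlineStats.Mean().
mutable struct RunningMean
    n::Int
    μ::Float64
end
RunningMean() = RunningMean(0, 0.0)

# Fold one per-step value into the running mean (incremental mean update).
function update!(s::RunningMean, x::Real)
    s.n += 1
    s.μ += (x - s.μ) / s.n
    return s
end

# One step: apply `device` to both arrays, compute the metric, update the statistic.
# `identity` plays the role of the default Flux.cpu here (assumption).
function metric_step!(s, metricfn, ŷs, ys; device = identity)
    value = metricfn(device(ŷs), device(ys))
    update!(s, value)
    return value
end

running = RunningMean()
metric_step!(running, mae, [1.0, 2.0], [1.0, 4.0])  # per-step metric: 1.0
metric_step!(running, mae, [0.0, 0.0], [3.0, 3.0])  # per-step metric: 3.0
println(running.μ)                                   # prints 2.0
```

Because the statistic only ever sees one number per step, swapping the running mean for another online statistic changes how step values are aggregated without touching metricfn.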

Examples

  • Metric(accuracy)

  • Metric(Flux.mse, device = gpu, name = "Mean Squared Error")

  • Metric(Flux.mae, device = gpu)


			cb = Metric(Flux.mse, device = gpu, name = "Mean Squared Error")

If a metric is expensive to compute and you don’t want it to slow down the
training phase, you can compute it during the validation phase only:


			cb = Metric(expensivemetric, phase = ValidationPhase)

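Restricting a metric to ValidationPhase can be pictured as a subtype check. The sketch below is a standalone illustration, not FluxTraining code: it defines its own hypothetical Phase, TrainingPhase and ValidationPhase types (FluxTraining provides its own) and a hypothetical shouldcompute helper, assuming the metric is computed only when the current phase is an instance of the given phase type.

```julia
# Hypothetical phase hierarchy for illustration; FluxTraining defines its own phase types.
abstract type Phase end
struct TrainingPhase <: Phase end
struct ValidationPhase <: Phase end

# Assumed rule (not confirmed by the docs): compute the metric when `phase isa P`.
shouldcompute(P::Type{<:Phase}, phase::Phase) = phase isa P

shouldcompute(Phase, TrainingPhase())              # true: the default Phase covers every phase
shouldcompute(ValidationPhase, TrainingPhase())    # false: skipped during training
shouldcompute(ValidationPhase, ValidationPhase())  # true
```

Passing the abstract Phase matches every concrete phase, which is why it works as the "compute always" default.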
Methods

There are 2 methods for Metric:

Metric(metricfn; name, statistic, device, phase)
callbacks/metrics.jl:177
Metric(metricfn, statistic::OnlineStatsBase.OnlineStat{T}, _statistic, name, device, P, last::Union{Nothing, T})
callbacks/metrics.jl:123