Metric

struct defined in module FluxTraining


			Metric(metricfn[; statistic, device, name, phase])

Implementation of AbstractMetric that can be used with the Metrics callback.

Arguments

Positional:

  • metricfn(ŷs, ys) should return a number.

Keyword:

  • statistic is an OnlineStats.Statistic that is updated after every step. The default is OnlineStats.Mean().

  • name is used for printing.

  • device is a function applied to ŷs and ys before they are passed to metricfn. The default is Flux.cpu, so the callback also works when metricfn doesn't support arrays from other device types. If, for example, metricfn works on CuArrays, you can pass device = Flux.gpu.

  • phase is a (sub)type of Phase that restricts the phases for which the metric is computed. The default is Phase, so the metric is computed in every phase.
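On every step, the keyword arguments above play distinct roles: device moves the outputs and targets, metricfn reduces them to a single number, and statistic folds that number into a running summary. The sketch below mimics that flow in plain Julia, without Flux or OnlineStats, so it runs on its own. RunningMean, fit_step! and update! are hypothetical stand-ins chosen for illustration, not FluxTraining's internals or API.

```julia
# Hypothetical stand-in for OnlineStats.Mean(): a mean updated one value at a time.
mutable struct RunningMean
    n::Int
    mean::Float64
end
RunningMean() = RunningMean(0, 0.0)

# Incremental mean update: mean_n = mean_{n-1} + (x - mean_{n-1}) / n
function fit_step!(s::RunningMean, x::Real)
    s.n += 1
    s.mean += (x - s.mean) / s.n
    return s
end

# Hypothetical per-step update mirroring the roles of `device`, `metricfn` and `statistic`.
function update!(statistic, metricfn, ŷs, ys; device = identity)
    value = metricfn(device(ŷs), device(ys))  # metricfn must return a number
    return fit_step!(statistic, value)
end

# Example metric on plain arrays: mean squared error.
mse(ŷs, ys) = sum(abs2, ŷs .- ys) / length(ys)

acc = RunningMean()
update!(acc, mse, [1.0, 2.0], [1.0, 0.0])  # step 1: mse = (0 + 4) / 2 = 2.0
update!(acc, mse, [0.0, 0.0], [0.0, 0.0])  # step 2: mse = 0.0
println(acc.mean)  # 1.0
```

A streaming statistic like this keeps constant memory no matter how many steps an epoch has, which is why a per-step value can be summarized without storing every batch result.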

Examples

  • Metric(accuracy)

  • Metric(Flux.mse, device = gpu, name = "Mean Squared Error")

  • Metric(Flux.mae, device = gpu)


			
			
			
			cb = Metric(Flux.mse, device = gpu, name = "Mean Squared Error")

If a metric is expensive to compute and you don't want it to slow down the training phase, you can compute it during the validation phase only:


			
			
			
			cb = Metric(expensivemetric, phase = ValidationPhase)
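One way to read "a (sub)type of Phase that restricts for which phases the metric is computed" is as a subtype check: the metric runs only when the current phase is an instance of the given type, so the abstract default Phase matches every phase. The sketch below illustrates that check with a small, self-contained hierarchy; SketchPhase, SketchTrainingPhase, SketchValidationPhase and shouldcompute are hypothetical names, and FluxTraining's actual phase types and dispatch may differ.

```julia
# Hypothetical phase hierarchy standing in for FluxTraining's Phase types.
abstract type SketchPhase end
struct SketchTrainingPhase <: SketchPhase end
struct SketchValidationPhase <: SketchPhase end

# Hypothetical helper: a metric restricted to type P runs only in phases that are instances of P.
shouldcompute(phase::SketchPhase, P::Type{<:SketchPhase}) = phase isa P

# Default (the abstract supertype): computed in every phase.
println(shouldcompute(SketchTrainingPhase(), SketchPhase))             # true
# Restricted to validation: skipped during training, computed during validation.
println(shouldcompute(SketchTrainingPhase(), SketchValidationPhase))   # false
println(shouldcompute(SketchValidationPhase(), SketchValidationPhase)) # true
```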
Methods

There are 2 methods for FluxTraining.Metric: