Metric
struct defined in module FluxTraining
Metric(metricfn[; statistic, device, name])
Implementation of AbstractMetric that can be used with the Metrics callback.
Positional:
metricfn(ŷs, ys) should return a number.
Keyword:
statistic is an OnlineStats.Statistic that is updated after every step. The default is OnlineStats.Mean().
name is used for printing.
device is a function applied to ŷs and ys before passing them to metricfn. The default is Flux.cpu, so the callback works even if metricfn doesn't support arrays from other device types. If, for example, metricfn works on CuArrays, you can pass device = Flux.gpu.
phase = Phase: a (sub)type of Phase that restricts the phases for which the metric is computed.
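To make the argument descriptions concrete, here is a minimal sketch in plain Julia of a function with the metricfn(ŷs, ys) shape, together with a running mean that shows how per-step values could be aggregated. The names columnaccuracy, RunningMean, and update! are hypothetical and introduced only for illustration. The running mean is a stand-in for the idea behind a mean statistic, not the OnlineStats or FluxTraining implementation.

```julia
# Hypothetical metric function with the metricfn(ŷs, ys) shape: returns a number.
# Assumes ŷs and ys are (classes × batch) matrices and ys is one-hot encoded.
function columnaccuracy(ŷs::AbstractMatrix, ys::AbstractMatrix)
    predicted = [argmax(col) for col in eachcol(ŷs)]
    expected = [argmax(col) for col in eachcol(ys)]
    return sum(predicted .== expected) / length(expected)
end

# Illustrative running mean (an assumption, not the OnlineStats implementation):
# it is updated once per step with the value returned by the metric function.
mutable struct RunningMean
    value::Float64
    n::Int
end
RunningMean() = RunningMean(0.0, 0)

function update!(m::RunningMean, x::Real)
    m.n += 1
    m.value += (x - m.value) / m.n  # incremental mean update
    return m
end

# Two toy "steps" with a batch of two observations each.
ŷs1 = [0.9 0.2; 0.1 0.8]   # predicts class 1, then class 2
ys1 = [1.0 0.0; 0.0 1.0]   # targets: class 1, then class 2
ŷs2 = [0.3 0.6; 0.7 0.4]   # predicts class 2, then class 1
ys2 = [0.0 0.0; 1.0 1.0]   # targets: class 2, then class 2

stat = RunningMean()
update!(stat, columnaccuracy(ŷs1, ys1))  # step 1: both predictions match
update!(stat, columnaccuracy(ŷs2, ys2))  # step 2: one of two predictions matches
```

A function like columnaccuracy operates on ordinary CPU arrays, which is the situation the default device = Flux.cpu is described as covering. Following the Metric(accuracy) form in the examples, it would presumably be passed as Metric(columnaccuracy).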
Metric(accuracy)
Metric(Flux.mse, device = gpu, name = "Mean Squared Error")
Metric(Flux.mae, device = gpu)
cb = Metric(Flux.mse, device = gpu, name = "Mean Squared Error")
If a metric is expensive to compute and you don't want it to slow down the training phase, you can compute it on the validation phase only:
cb = Metric(expensivemetric, phase = ValidationPhase)
There are 2 methods for FluxTraining.Metric: