Metric: struct defined in module FluxTraining
Metric(metricfn[; statistic, device, name, phase])
Implementation of AbstractMetric that can be used with the Metrics callback.
Positional:
metricfn(ŷs, ys) takes a batch of model outputs ŷs and targets ys and should return a number.
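To make the metricfn(ŷs, ys) contract concrete, here is a minimal sketch of a custom metric written in plain Julia. The name binary_accuracy and its threshold keyword are hypothetical and not part of FluxTraining or Flux. The sketch assumes ys holds 0/1 labels and ŷs holds scores of the same shape.

```julia
# Hypothetical metric function (not part of FluxTraining or Flux).
# It follows the metricfn(ŷs, ys) contract: two batches in, one number out.
# Assumes ys contains 0/1 labels and ŷs contains scores of the same shape.
function binary_accuracy(ŷs::AbstractArray{<:Real}, ys::AbstractArray{<:Real}; threshold = 0.5)
    size(ŷs) == size(ys) || throw(DimensionMismatch("ŷs and ys must have the same size"))
    predicted = ŷs .>= threshold      # turn scores into true/false decisions
    actual = ys .!= 0                 # treat any nonzero label as the positive class
    return count(predicted .== actual) / length(ys)   # fraction of correct decisions
end

binary_accuracy([0.9, 0.2, 0.7, 0.4], [1.0, 0.0, 0.0, 0.0])  # 0.75
```

A function like this could then be passed as Metric(binary_accuracy). Since it only uses generic array operations, it should also work with the default device = Flux.cpu.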
Keyword:
statistic is an OnlineStats.Statistic that is updated after every step. The default is OnlineStats.Mean().
name is used for printing.
device is a function applied to ŷs and ys before they are passed to metricfn. The default is Flux.cpu, so that the callback works even if metricfn doesn't support arrays from other device types. If, for example, metricfn works on CuArrays, you can pass device = Flux.gpu.
phase = Phase: a (sub)type of Phase that restricts the phases in which the metric is computed.
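Taken together, the keyword arguments describe what happens on each step: the batch is moved with device, metricfn turns it into a number, and statistic is updated with that number. The following Base-only sketch models that flow under those assumptions; it is not FluxTraining's implementation. RunningMean, MetricSketch, step! and mae are hypothetical names, RunningMean only imitates the incremental update of OnlineStats.Mean, and Array stands in for Flux.cpu.

```julia
# Hypothetical stand-in for OnlineStats.Mean: an incrementally updated mean.
mutable struct RunningMean
    n::Int        # number of values seen so far
    μ::Float64    # current mean
end
RunningMean() = RunningMean(0, 0.0)

# Incremental mean update: μₙ = μₙ₋₁ + (x - μₙ₋₁) / n
function fit_value!(m::RunningMean, x::Real)
    m.n += 1
    m.μ += (x - m.μ) / m.n
    return m
end

# Hypothetical model of a metric: a function, a device function, and a statistic.
struct MetricSketch{F,D}
    metricfn::F
    device::D                 # applied to both ŷs and ys before metricfn
    statistic::RunningMean
end
MetricSketch(metricfn; device = identity) = MetricSketch(metricfn, device, RunningMean())

# One step: move the batch, compute a single number, update the statistic.
function step!(m::MetricSketch, ŷs, ys)
    value = m.metricfn(m.device(ŷs), m.device(ys))
    fit_value!(m.statistic, value)
    return value
end

# Mean absolute error of one batch
mae(ŷs, ys) = sum(abs.(ŷs .- ys)) / length(ys)

m = MetricSketch(mae; device = Array)   # Array plays the role of Flux.cpu here
step!(m, [1.0, 2.0], [1.0, 1.0])        # batch MAE 0.5
step!(m, [0.0, 0.0], [1.5, 1.5])        # batch MAE 1.5
m.statistic.μ                           # 1.0
```

Because the statistic receives one value per batch, the result in this sketch is a mean of batch values, which can differ from a mean over all samples when batch sizes differ.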
Examples:
Metric(accuracy)
Metric(Flux.mse, device = gpu, name = "Mean Squared Error")
Metric(Flux.mae, device = gpu)
cb = Metric(Flux.mse, device = gpu, name = "Mean Squared Error")
If a metric is expensive to compute and you don't want it to slow down the training phase, you can compute it during the validation phase only:
cb = Metric(expensivemetric, phase = ValidationPhase)
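The effect of restricting a metric to the validation phase can be illustrated by counting how often the metric function actually runs. In this sketch the phase types, PhasedMetric and maybe_step! are hypothetical stand-ins (FluxTraining has its own Phase hierarchy and callback machinery), expensivemetric is a placeholder, and the batch counts are arbitrary: 3 epochs with 10 training and 2 validation batches each.

```julia
# Hypothetical stand-ins for phase types; FluxTraining defines its own Phase hierarchy.
abstract type SketchPhase end
struct SketchTrainingPhase <: SketchPhase end
struct SketchValidationPhase <: SketchPhase end

# A metric restricted to a phase type only runs when the current phase is of that type.
struct PhasedMetric{F}
    metricfn::F
    phase::Type{<:SketchPhase}
    calls::Base.RefValue{Int}   # counts how often metricfn actually ran
end
PhasedMetric(f; phase = SketchPhase) = PhasedMetric(f, phase, Ref(0))

function maybe_step!(m::PhasedMetric, current::SketchPhase, ŷs, ys)
    current isa m.phase || return nothing   # skip phases outside the restriction
    m.calls[] += 1
    return m.metricfn(ŷs, ys)
end

# Placeholder for a slow metric (here just the mean squared error).
expensivemetric(ŷs, ys) = sum(abs2, ŷs .- ys) / length(ys)

everyphase = PhasedMetric(expensivemetric)   # default: runs in every phase
validonly = PhasedMetric(expensivemetric; phase = SketchValidationPhase)

for epoch in 1:3
    for _ in 1:10, m in (everyphase, validonly)
        maybe_step!(m, SketchTrainingPhase(), [1.0], [0.0])
    end
    for _ in 1:2, m in (everyphase, validonly)
        maybe_step!(m, SketchValidationPhase(), [1.0], [0.0])
    end
end

everyphase.calls[]   # 36
validonly.calls[]    # 6
```

With these made-up numbers, the validation-only metric runs 6 times instead of 36; the saving grows with the ratio of training batches to validation batches.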
There are 2 methods for FluxTraining.Metric.