```julia
using FluxTraining
using FluxTraining: Callback, Read, Write, stateaccess

# placeholder values so the examples below can be constructed
model, data, lossfn = nothing, (), nothing
```
Callbacks allow injecting functionality at many points during the training loop.
To use them, simply pass each callback to `Learner`:
```julia
learner = Learner(
    model, lossfn,                             # required arguments
    callbacks = [ToGPU(), Metrics(accuracy)],  # pass any number of callbacks as additional arguments
    data = data)
```
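Once the learner is constructed, the callbacks hook into the training loop automatically; nothing extra has to be passed when training. A minimal sketch, assuming `model`, `lossfn`, and `data` are a real model, loss function, and data iterators rather than the placeholders above (depending on the FluxTraining.jl version, the data iterators may instead be passed to `fit!` directly):

```julia
# Training runs all registered callbacks (ToGPU, Metrics, and the defaults)
# at the appropriate points of the loop.
FluxTraining.fit!(learner, 10)  # train for 10 epochs
```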
Some useful callbacks are added by default:
```julia
learner.callbacks.cbs
```

```
FluxTraining.SafeCallback[ToDevice(Flux.gpu, Flux.gpu), Metrics(Loss(), Metric(Accuracy)), ProgressPrinter(), MetricsPrinter(), StopOnNaNLoss(), Recorder()]
```
See the callback reference for a list of all callbacks included in FluxTraining.jl and their documentation.
The order in which callbacks are passed doesn't matter: FluxTraining.jl builds a dependency graph that ensures the callbacks run in the correct order. Read the custom callbacks page to find out how to create callbacks yourself.
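As a taste of what that page covers, below is a minimal sketch of a custom callback built from the `Callback`, `Read`, `stateaccess`, and `on` extension points imported at the top. The event and phase names, and the exact layout of the learner state, are assumptions that may differ between FluxTraining.jl versions; see the custom callbacks page for the authoritative API.

```julia
# Hypothetical callback for illustration: log the loss at the end of each training step.
struct StepLossLogger <: Callback end

# Callbacks must declare which parts of the learner's state they touch;
# here we only need read access to the step state.
FluxTraining.stateaccess(::StepLossLogger) = (step = Read(),)

# Hook the callback into the `StepEnd` event during training phases.
function FluxTraining.on(
        ::FluxTraining.Events.StepEnd,
        ::FluxTraining.Phases.AbstractTrainingPhase,
        ::StepLossLogger,
        learner)
    @info "Step finished" loss = learner.step.loss
end

# It can then be passed to `Learner` like any other callback:
# Learner(model, lossfn, callbacks = [StepLossLogger()], data = data)
```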