`Learner` — struct defined in module `FluxTraining`

    Learner(model, lossfn; [callbacks = [], optimizer = ADAM(), kwargs...])
Holds and coordinates all state of the training. `model` is trained by optimizing `lossfn` with `optimizer` on `data`.
Positional arguments:

- `model`: A Flux.jl model or a `NamedTuple` of models.
- `lossfn`: Loss function with signature `lossfn(model(x), y) -> Number`.
Keyword arguments (optional):

- `data = ()`: Data iterators. A 2-tuple will be treated as `(trainingdataiter, validdataiter)`. You can also pass in an empty tuple `()` and use the `epoch!` method with a `dataiter` as the third argument. A data iterator is an iterable over batches. For regular supervised training, each batch should be a tuple `(xs, ys)`.
- `optimizer = ADAM()`: The optimizer used to update the `model`'s weights.
- `callbacks = []`: A list of callbacks that should be used. If `usedefaultcallbacks == true`, this list will be extended by the default callbacks.
- `usedefaultcallbacks = true`: Whether to add some basic callbacks. Included are `Metrics`, `Recorder`, `ProgressPrinter`, `StopOnNaNLoss`, and `MetricsPrinter`.
- `cbrunner = LinearRunner()`: Callback runner to use.
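As an illustrative sketch of these arguments (assuming Flux.jl and FluxTraining.jl are installed; the toy model and synthetic data below are made up for the example), a `Learner` might be constructed like this:

```julia
using Flux, FluxTraining

# A toy regression model and synthetic supervised data.
model = Chain(Dense(10, 32, relu), Dense(32, 1))

# A data iterator is any iterable over (xs, ys) batches;
# here, 16 batches of 8 samples each.
xs = [rand(Float32, 10, 8) for _ in 1:16]
ys = [rand(Float32, 1, 8) for _ in 1:16]
traindata = collect(zip(xs, ys))
validdata = traindata[1:4]

# A 2-tuple of iterators is treated as (trainingdataiter, validdataiter).
learner = Learner(
    model,
    Flux.mse;
    data = (traindata, validdata),
    optimizer = ADAM(0.01),
    callbacks = [],           # extended by defaults since usedefaultcallbacks = true
)
```

Since `usedefaultcallbacks` is left at `true`, the empty `callbacks` list is extended by `Metrics`, `Recorder`, `ProgressPrinter`, `StopOnNaNLoss`, and `MetricsPrinter`.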
Fields (use these as a reference when implementing callbacks):

- `model`, `optimizer`, and `lossfn` are stored as passed in.
- `data` is a `PropDict` of data iterators, usually `:training` and `:validation`.
- `params`: An instance of `model`'s parameters of type `Flux.Params`. If `model` is a `NamedTuple`, then `params` is a `NamedTuple` as well.
- `step::PropDict`: State of the last step. Contents depend on the last run `Phase`.
- `cbstate::PropDict`: A special state container that callbacks can save state to for other callbacks. Its keys depend on what callbacks are being used. See the custom callbacks guide for more info.
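As a hedged sketch of the empty-`data` workflow mentioned above (assuming Flux.jl and FluxTraining.jl are installed; `TrainingPhase` is part of FluxTraining's phase API, and the model and batches are made up):

```julia
using Flux, FluxTraining

model = Chain(Dense(10, 16, relu), Dense(16, 1))
batches = [(rand(Float32, 10, 8), rand(Float32, 1, 8)) for _ in 1:8]

# No data iterators are stored on the Learner itself ...
learner = Learner(model, Flux.mse; data = (), optimizer = ADAM())

# ... so the iterator is passed to `epoch!` as the third argument.
epoch!(learner, TrainingPhase(), batches)
```

After the epoch, `learner.step` holds the state of the last step (its contents depend on `TrainingPhase`), and `learner.cbstate` holds whatever the active callbacks saved.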
    Learner(model, data, optimizer, lossfn, [callbacks...; kwargs...])

There are 4 methods for `Learner`:

- `Learner(model, data::FluxTraining.PropDict, optimizer, lossfn, params, step::FluxTraining.PropDict, callbacks::FluxTraining.Callbacks, cbstate::FluxTraining.PropDict)`
- `Learner(model, lossfn; callbacks, data, optimizer, kwargs...)`
- `Learner(model, data, optimizer, lossfn, callbacks::FluxTraining.SafeCallback...; usedefaultcallbacks, cbrunner)`
- `Learner(model, data, optimizer, lossfn, params, step, callbacks, cbstate)`