Learner
struct defined in module FluxTraining
Learner(model, lossfn; [callbacks = [], optimizer = ADAM(), kwargs...])
Holds and coordinates all state of the training. model is trained by optimizing lossfn with optimizer on data.
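The training loop described above can be sketched end to end as follows. This is a minimal sketch, assuming Flux.jl and FluxTraining.jl are installed and that the keyword constructor behaves as documented on this page. The helper makebatches, the layer sizes and the random data are hypothetical and only for illustration.

```julia
using Flux, FluxTraining

# Hypothetical helper: n batches of 16 samples, 4 features -> 1 target.
# Each batch is an (xs, ys) tuple, so a Vector of tuples is a valid data iterator.
makebatches(n) = [(rand(Float32, 4, 16), rand(Float32, 1, 16)) for _ in 1:n]
traindata = makebatches(10)
validdata = makebatches(2)

model = Chain(Dense(4, 8, relu), Dense(8, 1))
lossfn = Flux.Losses.mse

# The 2-tuple is interpreted as (trainingdataiter, validdataiter);
# optimizer and callbacks are left at their defaults.
learner = Learner(model, lossfn; data = (traindata, validdata))

# Run two epochs; each epoch has a training phase and a validation phase.
fit!(learner, 2)
```

Because the default callbacks are enabled, this prints progress bars and a metrics table for each phase.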
Positional arguments:
- model: A Flux.jl model or a NamedTuple of models.
- lossfn: Loss function with signature lossfn(model(x), y) -> Number.
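The lossfn signature means the Learner calls the loss with the model's predictions first and the targets second, and expects a single number back. The sketch below checks that convention with a hand-written loss; the name mymse and the array shapes are hypothetical, and Flux.jl is assumed to be installed.

```julia
using Flux

# Hypothetical loss following lossfn(ŷ, y) -> Number, where ŷ = model(x).
# Predictions come first, targets second.
mymse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

model = Chain(Dense(3, 5, relu), Dense(5, 2))
x = rand(Float32, 3, 8)   # 8 samples with 3 features each
y = rand(Float32, 2, 8)   # matching targets

l = mymse(model(x), y)    # a scalar, as the Learner expects
println(l)

# Built-in Flux losses use the same (ŷ, y) argument order.
@assert l ≈ Flux.Losses.mse(model(x), y)
```

Any of Flux's built-in losses can therefore be passed directly as lossfn.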
Keyword arguments (optional):
- data = (): Data iterators. A 2-tuple is treated as (trainingdataiter, validdataiter). You can also pass an empty tuple () and instead pass a dataiter as the third argument to epoch!.
  A data iterator is an iterable over batches. For regular supervised training, each batch should be a tuple (xs, ys).
- optimizer = ADAM(): The optimizer used to update the model's weights.
- callbacks = []: A list of callbacks to use. If usedefaultcallbacks == true, this list is extended by the default callbacks.
- usedefaultcallbacks = true: Whether to add some basic callbacks. Included are Metrics, Recorder, ProgressPrinter, StopOnNaNLoss, and MetricsPrinter.
- cbrunner = LinearRunner(): The callback runner to use.
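When data is left as the empty tuple, the data iterator is supplied per call instead. The following is a minimal sketch, assuming FluxTraining.jl exports epoch! and TrainingPhase with the epoch!(learner, phase, dataiter) form described above; the toy batches and layer sizes are hypothetical.

```julia
using Flux, FluxTraining

model = Chain(Dense(4, 8, relu), Dense(8, 1))

# data defaults to (), so no iterators are stored in the Learner;
# default callbacks are still added because usedefaultcallbacks = true.
learner = Learner(model, Flux.Losses.mse)

# Hypothetical data iterator: any iterable over (xs, ys) batches works.
dataiter = [(rand(Float32, 4, 16), rand(Float32, 1, 16)) for _ in 1:5]

# Run one training epoch over the explicitly passed iterator.
epoch!(learner, TrainingPhase(), dataiter)
```

This form is convenient when the iterator changes between epochs, since nothing has to be rebuilt in the Learner.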
Fields (use this as a reference when implementing callbacks):
- model, optimizer, and lossfn are stored as passed in.
- data is a PropDict of data iterators, usually with the keys :training and :validation.
- params: An instance of model's parameters of type Flux.Params. If model is a NamedTuple, then params is a NamedTuple as well.
- step::PropDict: State of the last step. Its contents depend on the last run Phase.
- cbstate::PropDict: Special state container that callbacks can save state to for other callbacks. Its keys depend on which callbacks are being used. See the custom callbacks guide for more info.
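The fields above can be inspected directly, which is what a callback does through its learner argument. This is a minimal sketch, assuming Flux.jl and FluxTraining.jl versions matching this page (params of type Flux.Params). The toy data is hypothetical, and the `loss` key inside step is an assumption about its contents rather than something stated above.

```julia
using Flux, FluxTraining

model = Chain(Dense(2, 4, relu), Dense(4, 1))
lossfn = Flux.Losses.mse
# Hypothetical toy data: batches of (xs, ys) with 8 samples each.
traindata = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:4]
validdata = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:2]

learner = Learner(model, lossfn; data = (traindata, validdata))

# model and lossfn are stored as passed in.
@show learner.model === model
@show learner.lossfn === lossfn

# data is a PropDict; the 2-tuple is available under :training and :validation.
@show length(learner.data.training)

# params holds the model's trainable arrays as Flux.Params.
@show learner.params isa Flux.Params

# After one epoch, step holds the state of the last step of the last phase
# (validation here). Reading `loss` assumes that key is present.
fit!(learner, 1)
@show learner.step.loss
```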
Learner(model, data, optimizer, lossfn, [callbacks...; kwargs...])
There are 3 methods for FluxTraining.Learner.