Learner
struct defined in module FluxTraining
Learner(model, lossfn; [callbacks = [], optimizer = ADAM(), kwargs...])
Holds and coordinates all state of the training. model is trained by optimizing lossfn with optimizer on data.
Positional arguments:
model: A Flux.jl model or a NamedTuple of models.
lossfn: Loss function with signature lossfn(model(x), y) -> Number.
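The lossfn signature above can be illustrated without any packages. This is only a sketch: the toy model below is a plain Julia function standing in for a Flux.jl model, and the hand-written mean squared error is a hypothetical example, not a function provided by FluxTraining.

```julia
# A stand-in "model": any callable is enough to illustrate the call shape.
model(x) = 2 .* x

# Hand-written mean squared error: takes predictions and targets, returns a scalar.
lossfn(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

x = [1.0, 2.0, 3.0]
y = [2.0, 4.0, 7.0]

# The predictions are [2.0, 4.0, 6.0], so only the last entry contributes to the loss.
loss = lossfn(model(x), y)
println(loss)
```

Any function that accepts the model output and the targets and returns a single number fits this signature.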
Keyword arguments (optional):
data = (): Data iterators. A 2-tuple will be treated as (trainingdataiter, validdataiter). You can also pass in an empty tuple () and use the epoch! method with a dataiter as third argument. A data iterator is an iterable over batches. For regular supervised training, each batch should be a tuple (xs, ys).
optimizer = ADAM(): The optimizer used to update the model's weights.
callbacks = []: A list of callbacks that should be used. If usedefaultcallbacks == true, this will be extended by the default callbacks.
usedefaultcallbacks = true: Whether to add some basic callbacks. Included are Metrics, Recorder, ProgressPrinter, StopOnNaNLoss, and MetricsPrinter.
cbrunner = LinearRunner(): Callback runner to use.
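The two ways of supplying data described above can be sketched as follows. This is a minimal example, not a recommended setup: the random toy data, the layer sizes, and the learning rate are made up for illustration, and it assumes a Flux.jl version that still exports ADAM and accepts the Dense(in, out, activation) constructor, in line with the signature shown on this page.

```julia
using Flux
using FluxTraining

# Toy regression data: 4 input features, 1 target, 16 observations per batch.
# Each batch is a tuple (xs, ys), as expected for regular supervised training.
makebatch() = (rand(Float32, 4, 16), rand(Float32, 1, 16))
traindata = [makebatch() for _ in 1:8]
validdata = [makebatch() for _ in 1:2]

model = Chain(Dense(4, 8, relu), Dense(8, 1))
lossfn = Flux.Losses.mse

# Variant 1: pass the iterators as a 2-tuple (training, validation).
learner = Learner(model, lossfn; data = (traindata, validdata), optimizer = ADAM(0.01))
fit!(learner, 2)  # two epochs, each a training pass followed by a validation pass

# Variant 2: leave data empty and hand an iterator to epoch! as the third argument.
learner2 = Learner(model, lossfn)
epoch!(learner2, TrainingPhase(), traindata)
```

Since usedefaultcallbacks defaults to true, both learners print progress and metrics while training.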
Fields (use this as a reference when implementing callbacks):
model, optimizer, and lossfn are stored as passed in.
data is a PropDict of data iterators, usually :training and :validation.
params: The parameters of model, of type Flux.Params. If model is a NamedTuple, then params is a NamedTuple as well.
step::PropDict: State of the last step. Contents depend on the last run Phase.
cbstate::PropDict: Special state container that callbacks can save state to for other callbacks. Its keys depend on what callbacks are being used. See the custom callbacks guide for more info.
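As a rough illustration of the fields listed above, the sketch below builds a small learner, runs one training epoch, and reads a few fields back. The toy data is made up, and the loss key on step is an assumption about what a training phase writes there; as noted above, the contents of step depend on the phase that ran last.

```julia
using Flux
using FluxTraining

batches = [(rand(Float32, 4, 16), rand(Float32, 1, 16)) for _ in 1:4]
model = Chain(Dense(4, 1))
lossfn = Flux.Losses.mse

learner = Learner(model, lossfn; data = (batches, batches))

# model and lossfn are stored as passed in.
samemodel = learner.model === model
sameloss = learner.lossfn === lossfn

# data is a PropDict; a 2-tuple ends up under :training and :validation.
trainiter = learner.data.training

# Run one training epoch so that step is populated.
epoch!(learner, TrainingPhase())

# Assumption: after a training phase, step has a loss entry.
lastloss = learner.step.loss
println("loss of the last training step: ", lastloss)
```

A callback would read the same fields from the learner it receives, which is why the field list doubles as a reference for callback authors.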
Learner(model, data, optimizer, lossfn, [callbacks...; kwargs...])
There are 3 methods for FluxTraining.Learner.