Learner

struct defined in module FluxTraining


			Learner(model, lossfn; [callbacks = [], data = (), optimizer = ADAM(), kwargs...])

Holds and coordinates all training state. model is trained by
optimizing lossfn with optimizer on data.

Arguments

Positional arguments:

  • model: A Flux.jl model or a NamedTuple of models.

  • lossfn: Loss function with signature lossfn(model(x), y) -> Number.

Keyword arguments (optional):

  • data = (): Data iterators. A 2-tuple is treated as (trainingdataiter, validdataiter).
    You can also pass an empty tuple () and call the epoch! method with a
    dataiter as its third argument.

    A data iterator is an iterable over batches. For regular supervised training,
    each batch should be a tuple (xs, ys).

  • optimizer = ADAM(): The optimizer used to update the model's weights.

  • callbacks = []: A list of callbacks to use. If usedefaultcallbacks == true,
    the default callbacks are appended to this list.

  • usedefaultcallbacks = true: Whether to add some basic callbacks. Included
    are Metrics, Recorder, ProgressPrinter,
    StopOnNaNLoss, and MetricsPrinter.

  • cbrunner = LinearRunner(): Callback runner to use.
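
The arguments above can be combined as in the following minimal sketch. It assumes a Flux.jl version that still exports ADAM and a FluxTraining.jl version matching the signature documented here. The toy data, the batch shapes, the helper makebatch, and the model architecture are illustrative assumptions, not part of the library.

```julia
using Flux, FluxTraining

# Hypothetical toy regression data. Each data iterator is an iterable over
# batches, and each batch is a tuple (xs, ys). Shapes are arbitrary.
makebatch() = (rand(Float32, 10, 16), rand(Float32, 1, 16))
traindata = [makebatch() for _ in 1:8]
validdata = [makebatch() for _ in 1:2]

# A small Flux model and a loss that is called as lossfn(model(x), y).
model = Chain(Dense(10, 32, relu), Dense(32, 1))
lossfn = Flux.Losses.mse

# A 2-tuple passed as `data` is treated as (trainingdataiter, validdataiter).
learner = Learner(model, lossfn;
    data = (traindata, validdata),
    optimizer = ADAM(),
    callbacks = [],
    usedefaultcallbacks = false)  # skip the default printing/recording callbacks

# Alternatively, pass a data iterator explicitly as the third argument of epoch!.
epoch!(learner, TrainingPhase(), traindata)
epoch!(learner, ValidationPhase(), validdata)
```

Leaving usedefaultcallbacks at its default of true would additionally add the callbacks listed above (Metrics, Recorder, and so on); it is disabled here only to keep the sketch quiet.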

Fields

(Use this as a reference when implementing callbacks)

  • model, optimizer, and lossfn are stored as passed in.

  • data is a PropDict of data iterators, usually :training and :validation.

  • params: The parameters of model, of type Flux.Params. If model is
    a NamedTuple, then params is a NamedTuple as well.

  • step::PropDict: State of the last step. Contents depend on the last run
    Phase.

  • cbstate::PropDict: Special state container in which callbacks can
    store state for other callbacks to read. Its keys depend on which
    callbacks are in use. See the custom callbacks guide
    for more information.



			Learner(model, data, optimizer, lossfn, [callbacks...; kwargs...])

Methods

There are 4 methods for Learner:

Learner(model, data::FluxTraining.PropDict, optimizer, lossfn, params, step::FluxTraining.PropDict, callbacks::FluxTraining.Callbacks, cbstate::FluxTraining.PropDict)
    learner.jl:15
Learner(model, lossfn; callbacks, data, optimizer, kwargs...)
    learner.jl:72
Learner(model, data, optimizer, lossfn, callbacks::FluxTraining.SafeCallback...; usedefaultcallbacks, cbrunner)
    learner.jl:80
Learner(model, data, optimizer, lossfn, params, step, callbacks, cbstate)
    learner.jl:15