Learner

struct defined in module FluxTraining


			Learner(model, lossfn; [callbacks = [], optimizer = ADAM(), kwargs...])

Holds and coordinates all state of the training process. model is trained by optimizing lossfn with optimizer on data.
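
For example, a minimal Learner for a toy regression task might be constructed like this (the model and loss are illustrative stand-ins, not part of the API):

    using Flux, FluxTraining

    # Any Flux.jl model works; this toy regression model is just for illustration.
    model = Chain(Dense(1, 16, relu), Dense(16, 1))

    # The loss function receives the model output and the targets.
    lossfn = Flux.Losses.mse

    learner = Learner(model, lossfn; optimizer = ADAM())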

Arguments

Positional arguments:

  • model: A Flux.jl model or a NamedTuple of models.

  • lossfn: Loss function with signature lossfn(model(x), y) -> Number.

Keyword arguments (optional):

  • data = (): Data iterators. A 2-tuple will be treated as (trainingdataiter, validdataiter). You can also pass in an empty tuple () and instead pass a data iterator as the third argument to the epoch! method (see the sketch after this list).

    A data iterator is an iterable over batches. For regular supervised training, each batch should be a tuple (xs, ys).

  • optimizer = ADAM(): The optimizer used to update the model's weights.

  • callbacks = []: A list of callbacks that should be used. If usedefaultcallbacks == true, this list is extended by the default callbacks.

  • usedefaultcallbacks = true: Whether to add some basic callbacks. Included are Metrics, Recorder, ProgressPrinter, StopOnNaNLoss, and MetricsPrinter.

  • cbrunner = LinearRunner(): Callback runner to use.
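
As referenced in the data argument above, here is a sketch of both ways of supplying data. The model, loss, and iterator are illustrative; a data iterator can be any iterable of (xs, ys) batches, e.g. a plain vector of tuples as used here:

    using Flux, FluxTraining

    model = Chain(Dense(1, 16, relu), Dense(16, 1))
    lossfn = Flux.Losses.mse

    # A data iterator is just an iterable of batches, here (xs, ys) tuples.
    traindata = [(rand(Float32, 1, 16), rand(Float32, 1, 16)) for _ in 1:8]

    # Either pass the iterators at construction time...
    learner = Learner(model, lossfn; data = (traindata, traindata))
    fit!(learner, 2)  # train and validate for 2 epochs

    # ...or keep `data` empty and pass an iterator to `epoch!` directly.
    learner = Learner(model, lossfn)
    epoch!(learner, TrainingPhase(), traindata)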

Fields

(Use this as a reference when implementing callbacks)

  • model, optimizer, and lossfn are stored as passed in.

  • data is a PropDict of data iterators, usually :training and :validation.

  • params: The parameters of model, an instance of Flux.Params. If model is a NamedTuple, then params is a NamedTuple as well.

  • step::PropDict: State of the last step. Contents depend on the last run Phase.

  • cbstate::PropDict: Special state container that callbacks can write state to so that other callbacks can read it. Its keys depend on what callbacks are being used. See the custom callbacks guide for more info.
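
As a sketch of how a callback reads this state, the hypothetical LossLogger below logs the loss of every training step. It is a minimal example assuming the stateaccess/on callback API described in the custom callbacks guide:

    import FluxTraining: Callback, Read, stateaccess, on, TrainingPhase
    import FluxTraining.Events: StepEnd

    # Hypothetical callback that logs the loss after every training step.
    struct LossLogger <: Callback end

    # Callbacks declare which parts of the learner state they access.
    stateaccess(::LossLogger) = (step = Read(),)

    # `learner.step` holds the last step's state; for supervised training
    # phases it includes the field `loss`.
    function on(::StepEnd, ::TrainingPhase, ::LossLogger, learner)
        @info "Training step finished" loss = learner.step.loss
    end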


An alternative positional form of the constructor is also available:

			Learner(model, data, optimizer, lossfn, [callbacks...; kwargs...])