mutable struct Callbacks
    cbs::Vector
    runner::CallbackRunner
    graph::SimpleDiGraph
    initialized::Bool
end
Callbacks(cbs, runner = LinearRunner()) = Callbacks(cbs, runner, callbackgraph(cbs), false)

init!(cbs::Callbacks, learner) = foreach(cb -> init!(cb, learner), cbs.cbs)
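# A minimal usage sketch (illustrative; `Recorder` and `Metrics` are callbacks
# defined elsewhere in this package):
#
#     cbs = Callbacks([Recorder(), Metrics()])   # uses the default `LinearRunner`
#     cbs.initialized                            # starts out `false`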
mutable struct Learner
    model
    data::PropDict
    optimizer
    lossfn
    # this used to store `Flux.Params` but now stores the optimiser state
    # if an optim from Optimisers.jl is used
    params
    step::PropDict
    callbacks::Callbacks
    cbstate::PropDict
end
"""
Learner(model, lossfn; [callbacks = [], optimizer = ADAM(), kwargs...])
Holds and coordinates all state of the training. `model` is trained by
optimizing `lossfn` with `optimizer` on `data`.
## Arguments
Positional arguments:
- `model`: A Flux.jl model or a `NamedTuple` of models.
- `lossfn`: Loss function with signature `lossfn(model(x), y) -> Number`.
Keyword arguments (optional):
- `data = ()`: Data iterators. A 2-tuple will be treated as `(trainingdataiter, validdataiter)`.
You can also pass in an empty tuple `()` and use the [`epoch!`](#) method with a
`dataiter` as third argument.
A data iterator is an iterable over batches. For regular supervised training,
each batch should be a tuple `(xs, ys)`.
- `optimizer = ADAM()`: The optimizer used to update the `model`'s weights
- `callbacks = []`: A list of callbacks that should be used. If `usedefaultcallbacks == true`,
this will be extended by the default callbacks
- `usedefaultcallbacks = true`: Whether to add some basic callbacks. Included
are [`Metrics`](#), [`Recorder`](#), [`ProgressPrinter`](#),
[`StopOnNaNLoss`](#), and [`MetricsPrinter`](#).
- `cbrunner = LinearRunner()`: Callback runner to use.
## Fields
*(Use this as a reference when implementing callbacks)*
- `model`, `optimizer`, and `lossfn` are stored as passed in
- `data` is a `PropDict` of data iterators, usually `:training` and `:validation`.
- `params`: The optimisation state for `model`. With a Flux.jl optimiser this is a
`Flux.Params`; with an Optimisers.jl optimiser it is the state returned by
`Optimisers.setup`. If `model` is a `NamedTuple`, then `params` is a `NamedTuple` as well.
- `step::`[`PropDict`](#): State of the last step. Contents depend on the last run
[`Phase`](#).
- `cbstate::`[`PropDict`](#): Special state container that callbacks can
save state to for other callbacks. Its keys depend on what callbacks
are being used. See the [custom callbacks guide](/docs/callbacks/custom.md)
for more info.
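
## Example

A minimal sketch (assuming `model`, `traindata`, and `valdata` are defined elsewhere;
`TrainingPhase` and `epoch!` come from this package):

```julia
learner = Learner(model, Flux.Losses.logitcrossentropy;
                  data = (traindata, valdata), optimizer = ADAM())

# run a single training epoch over a specific data iterator
epoch!(learner, TrainingPhase(), traindata)
```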
"""
function Learner(model, lossfn; callbacks = [], data = (), optimizer = ADAM(), kwargs...)
    return Learner(model, data, optimizer, lossfn, callbacks...; kwargs...)
end
# Note that `Vararg` is covariant; using `Vararg` here produces a warning.
"""
Learner(model, data, optimizer, lossfn, [callbacks...; kwargs...])
"""
function Learner(model, data, optimizer, lossfn, callbacks::Vararg{Callback};
                 usedefaultcallbacks = true, cbrunner = LinearRunner())
    callbacks = collect(Callback, callbacks)

    if usedefaultcallbacks
        for cb in defaultcallbacks()
            if !any(typeof(cb) .== typeof.(callbacks))
                push!(callbacks, cb)
            end
        end
    end

    cbs = Callbacks(callbacks, cbrunner)
    learner = Learner(
        model, _dataiters(data), optimizer, lossfn,
        setupoptimstate(model, optimizer), PropDict(), cbs, PropDict())
    init!(cbs, learner)

    return learner
end
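# Default callbacks are only added when no callback of the same concrete type was
# passed explicitly. For example (illustrative), passing `Recorder()` yourself means
# the `Recorder()` from `defaultcallbacks()` is skipped, while the other defaults
# are still added:
#
#     Learner(model, data, ADAM(), lossfn, Recorder())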
Base.show(io::IO, learner::Learner) = print(io, "Learner()")
defaultcallbacks()::Vector{AbstractCallback} = [
    ProgressPrinter(),
    MetricsPrinter(),
    StopOnNaNLoss(),
    Recorder(),
    Metrics(),
]
# Callback handling

handle(event, learner, phase) = handle(learner.callbacks.runner, event, phase, learner)
# Other

phasedataiter(::AbstractTrainingPhase) = :training
phasedataiter(::AbstractValidationPhase) = :validation
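# For example (illustrative), the data iterator for the current phase can be looked
# up from `learner.data` using this key:
#
#     learner.data[phasedataiter(TrainingPhase())]  # the `:training` iterator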
function model!(learner, model)
    learner.model = model
    learner.params = setupoptimstate(model, learner.optimizer)
end
# Flux.jl optimisers store `params`, while Optimisers.jl optimisers store the result of `setup`.
setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser) = Flux.params(model)
# Optimisers.jl has no abstract supertype, so we assume non-Flux optimisers conform
# to the Optimisers.jl interface.
setupoptimstate(model, optim) = Optimisers.setup(optim, model)
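# Dispatch sketch (illustrative; assumes `Flux` and `Optimisers` are loaded and a
# `model` is defined):
#
#     setupoptimstate(model, Flux.Optimise.Descent(0.1))  # -> Flux.params(model)
#     setupoptimstate(model, Optimisers.Adam())           # -> Optimisers.setup(Optimisers.Adam(), model)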
_dataiters(d::PropDict) = d
_dataiters(t::NamedTuple) = PropDict(pairs(t))
function _dataiters(t::Tuple)
    if length(t) == 0
        return PropDict(Dict{Symbol, Any}())
    elseif length(t) == 1
        # note the trailing comma: `(training = t[1],)` is a `NamedTuple`,
        # while `(training = t[1])` is just an assignment in parentheses
        return _dataiters((training = t[1],))
    elseif length(t) == 2
        return _dataiters((training = t[1], validation = t[2]))
    else
        error("Please pass a `NamedTuple` or `PropDict` as `data`.")
    end
end
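# For example (illustrative), a 2-tuple of iterators is normalised into a `PropDict`
# with `:training` and `:validation` keys:
#
#     _dataiters((traindl, valdl))  # PropDict with keys :training and :validation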