# Learner and callback container definitions.
"""
Bundle of a learner's callbacks plus the machinery to run them.

Fields:
- `cbs`: the callback instances (`Vector`; element type not constrained here)
- `runner`: the `CallbackRunner` that dispatches events to the callbacks
- `graph`: dependency graph over the callbacks (`SimpleDiGraph`)
- `initialized`: whether the callbacks have been initialized
  (NOTE(review): not mutated anywhere in this chunk — presumably set elsewhere)
"""
mutable struct Callbacks
    cbs::Vector
    runner::CallbackRunner
    graph::SimpleDiGraph
    initialized::Bool
end
# Outer constructor: derive the dependency graph from the callbacks and
# start out in the not-yet-initialized state.
function Callbacks(cbs, runner = LinearRunner())
    return Callbacks(cbs, runner, callbackgraph(cbs), false)
end
"""
    init!(cbs::Callbacks, learner)

Run `init!` on every callback held by `cbs`, passing it `learner`.
"""
function init!(cbs::Callbacks, learner)
    for cb in cbs.cbs
        init!(cb, learner)
    end
    return nothing
end
"""
Mutable container holding everything a training run needs.

Fields:
- `model`: the model being trained
- `data`: data iterators keyed by phase name (`PropDict`), e.g. `:training` / `:validation`
- `optimizer`: the optimiser used for parameter updates
- `lossfn`: the loss function
- `params`: parameter/optimiser state (see comment below)
- `step`: per-step working state (`PropDict`)
- `callbacks`: the `Callbacks` bundle for this learner
- `cbstate`: state shared between callbacks (`PropDict`)
"""
mutable struct Learner
    model
    data::PropDict
    optimizer
    lossfn
    # this used to store `Flux.Params` but now stores the optimiser state
    # if an optim from Optimisers.jl is used
    params
    step::PropDict
    callbacks::Callbacks
    cbstate::PropDict
end
"""
    Learner(model, lossfn; callbacks = [], data = (), optimizer = ADAM(), kwargs...)

Keyword-style convenience constructor; forwards everything to the positional
`Learner(model, data, optimizer, lossfn, callbacks...)` method.
"""
Learner(model, lossfn; callbacks = [], data = (), optimizer = ADAM(), kwargs...) =
    Learner(model, data, optimizer, lossfn, callbacks...; kwargs...)
"""
    Learner(model, data, optimizer, lossfn, callbacks...; kwargs...)

Assemble a `Learner` from its parts, attach the callbacks, and run their
`init!` hooks before returning.

# Keywords
- `usedefaultcallbacks = true`: also add each default callback whose concrete
  type is not already present among `callbacks`
- `cbrunner = LinearRunner()`: runner used to execute the callbacks
"""
function Learner(model, data, optimizer, lossfn, callbacks::Vararg{<:Callback};
                 usedefaultcallbacks = true, cbrunner = LinearRunner())
    cbvec = collect(Callback, callbacks)
    if usedefaultcallbacks
        # Only add defaults the caller did not supply themselves,
        # matching by concrete callback type.
        for defaultcb in defaultcallbacks()
            if !any(c -> typeof(c) == typeof(defaultcb), cbvec)
                push!(cbvec, defaultcb)
            end
        end
    end

    cbs = Callbacks(cbvec, cbrunner)
    learner = Learner(
        model,
        _dataiters(data),
        optimizer,
        lossfn,
        setupoptimstate(model, optimizer),
        PropDict(),
        cbs,
        PropDict(),
    )
    init!(cbs, learner)
    return learner
end
# Compact single-line display for a `Learner`.
function Base.show(io::IO, ::Learner)
    print(io, "Learner()")
end
"""
    defaultcallbacks()

Return the set of callbacks every `Learner` gets by default.
"""
function defaultcallbacks()::Vector{AbstractCallback}
    return [
        ProgressPrinter(),
        MetricsPrinter(),
        StopOnNaNLoss(),
        Recorder(),
        Metrics(),
    ]
end
# Callback handling
# Event entry point: delegate to the learner's configured callback runner.
function handle(event, learner, phase)
    return handle(learner.callbacks.runner, event, phase, learner)
end
# Other
# Name of the data iterator a phase reads from (key into `learner.data`).
function phasedataiter(::AbstractTrainingPhase)
    return :training
end
# Validation phases read from the `:validation` data iterator.
function phasedataiter(::AbstractValidationPhase)
    return :validation
end
"""
    model!(learner, model)

Replace `learner`'s model with `model` and rebuild the optimiser state so it
matches the new model's parameters.
"""
function model!(learner, model)
    learner.model = model
    # A new model invalidates the old parameter/optimiser state.
    learner.params = setupoptimstate(model, learner.optimizer)
end
# Flux.jl optimisers store `params`, while Optimisers.jl store the result of `setup`.
# Flux's legacy implicit-style optimisers train against `Flux.params(model)`.
function setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser)
    return Flux.params(model)
end
# Optimisers.jl has no abstract supertype, so we assume non-Flux optimisers conform to the Optimisers.jl interface.
# Fallback: treat any other optimiser as an Optimisers.jl-style one.
function setupoptimstate(model, optim)
    return Optimisers.setup(optim, model)
end
# A `PropDict` of data iterators is already in canonical form.
function _dataiters(d::PropDict)
    return d
end
# Convert a `NamedTuple` of data iterators into a `PropDict`.
function _dataiters(t::NamedTuple)
    return PropDict(pairs(t))
end
"""
    _dataiters(t::Tuple)

Normalize a plain tuple of data iterators into a `PropDict`:
an empty tuple yields an empty `PropDict`, a 1-tuple is treated as
`(training = …,)`, a 2-tuple as `(training = …, validation = …)`.
Throws an error for longer tuples.
"""
function _dataiters(t::Tuple)
    if length(t) == 0
        return PropDict(Dict{Symbol,Any}())
    elseif length(t) == 1
        # BUGFIX: the trailing comma is required. `(training = t[1])` is a
        # parenthesized *assignment* evaluating to `t[1]`, not a one-element
        # `NamedTuple`, so the NamedTuple method was never reached.
        return _dataiters((training = t[1],))
    elseif length(t) == 2
        return _dataiters((training = t[1], validation = t[2]))
    else
        error("Please pass a `NamedTuple` or `PropDict` as `data`.")
    end
end