"""
    ProgressPrinter()

Prints a progress bar of the currently running epoch.
"""
mutable struct ProgressPrinter <: Callback
    p::Union{Nothing,Progress}
end

ProgressPrinter() = ProgressPrinter(nothing)

Base.show(io::IO, ::ProgressPrinter) = print(io, "ProgressPrinter()")
function on(::EpochBegin, phase::Phase, cb::ProgressPrinter, learner)
    e = learner.cbstate.history[phase].epochs + 1
    dataiter = get(learner.data, phasedataiter(phase), nothing)
    if isnothing(dataiter)
        cb.p = nothing
        println("Epoch $(e) $(phase) ...")
    else
        cb.p = Progress(length(dataiter), "Epoch $(e) $(phase): ")
    end
end
on(::StepEnd, ::Phase, cb::ProgressPrinter, learner) = isnothing(cb.p) || next!(cb.p)

runafter(::ProgressPrinter) = (Recorder,)
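# The trailing comma makes `(history = Read(),)` a NamedTuple; without it the
# parentheses would only wrap an assignment.
stateaccess(::ProgressPrinter) = (data = Read(), cbstate = (history = Read(),),)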
"""
    MetricsPrinter() <: Callback

Callback that prints metrics after every epoch. Relies on the metrics computed by
[`Metrics`](#), so will error if no `Metrics` callback is used.

This callback is added by default to every [`Learner`](#) unless you pass in
`usedefaultcallbacks = false`.
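
A minimal sketch of opting out of the defaults and adding the printer by hand. It
assumes the keyword constructor `Learner(model, lossfn; callbacks, usedefaultcallbacks)`
from recent FluxTraining versions, and that `Recorder` provides the epoch history this
callback reads; the model and loss are placeholders.

```julia
import Flux
import FluxTraining: Learner, Recorder, Metrics, MetricsPrinter

model = Flux.Dense(2, 1)  # placeholder model, for illustration only

learner = Learner(
    model,
    Flux.Losses.mse;
    usedefaultcallbacks = false,
    # `Metrics` has to be present, otherwise `MetricsPrinter` errors.
    callbacks = [Recorder(), Metrics(), MetricsPrinter()],
)
```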
"""
struct MetricsPrinter <: Callback end
function on(::EpochEnd, phase::Phase, cb::MetricsPrinter, learner)
    mvhistory = learner.cbstate.metricsepoch[phase]
    epoch = learner.cbstate.history[phase].epochs
    print_epoch_table(mvhistory, epoch, phase)
end
function print_epoch_table(mvhistory, epoch, phase)
    # One column per recorded metric, preceded by the phase and epoch columns.
    header = vcat(["Phase", "Epoch"], string.(keys(mvhistory)))
    # For each metric, take the most recent entry and keep only its value.
    vals = [last(mvhistory, key) |> last for key in keys(mvhistory)]
    # Lay the values out as a single table row.
    data = reshape(vcat([string(phase), epoch], vals), 1, :)
    pretty_table(data; header = header, formatters = PrettyTables.ft_round(5))
end
stateaccess(::MetricsPrinter) = (; cbstate = (metricsepoch = Read(), history = Read()))

runafter(::MetricsPrinter) = (Metrics,)
"""
    StopOnNaNLoss()

Stops the training when a NaN loss is encountered.

This callback is added by default to every [`Learner`](#) unless you pass in
`usedefaultcallbacks = false`.
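
The check can be seen in isolation by calling the handler directly. This is a sketch
rather than the usual way to use the callback: the named tuple is a hypothetical
stand-in for a `Learner`, and it assumes `BackwardEnd` and `TrainingPhase` can be
constructed without arguments.

```julia
import FluxTraining
import FluxTraining: StopOnNaNLoss, BackwardEnd, TrainingPhase, CancelFittingException

# Hypothetical stand-in for a learner whose last step produced a NaN loss.
fakelearner = (step = (loss = NaN,),)

cancelled = try
    FluxTraining.on(BackwardEnd(), TrainingPhase(), StopOnNaNLoss(), fakelearner)
    false
catch err
    # The handler signals cancellation by throwing this exception.
    err isa CancelFittingException
end
```

Under those assumptions `cancelled` ends up `true`; with a finite loss the handler
returns without throwing.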
"""
struct StopOnNaNLoss <: Callback end

function on(::BackwardEnd, ::AbstractTrainingPhase, ::StopOnNaNLoss, learner)
    !isnan(learner.step.loss) || throw(CancelFittingException("Encountered NaN loss"))
end
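# The trailing comma makes `(loss = Read(),)` a NamedTuple; without it the
# parentheses would only wrap an assignment.
stateaccess(::StopOnNaNLoss) = (step = (loss = Read(),),)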
"""
    ToDevice(movedatafn, movemodelfn) <: Callback

Moves model and step data to a device using `movedatafn` for step data
and `movemodelfn` for the model. For example, `ToDevice(Flux.gpu, Flux.gpu)`
moves them to a GPU if available. See [`ToGPU`](#).

By default, only moves `step.xs` and `step.ys`, but this can be extended
to other state by implementing `on(::StepBegin, ::MyCustomPhase, ::ToDevice, learner)`.
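
A minimal sketch of such an extension. `MyCustomPhase` and the `weights` field on the
step state are hypothetical, introduced only for illustration.

```julia
import FluxTraining
import FluxTraining: ToDevice, StepBegin, Phase

# Hypothetical phase whose step state carries an extra `weights` field.
struct MyCustomPhase <: Phase end

function FluxTraining.on(::StepBegin, ::MyCustomPhase, cb::ToDevice, learner)
    # Same as the default behavior ...
    learner.step.xs = cb.movedatafn(learner.step.xs)
    learner.step.ys = cb.movedatafn(learner.step.ys)
    # ... plus the additional state this phase needs on the device.
    learner.step.weights = cb.movedatafn(learner.step.weights)
end
```

Because the method is more specific in its phase argument, it takes precedence over
the generic `::Phase` method for steps run in `MyCustomPhase`.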
"""
struct ToDevice <: Callback
    movedatafn
    movemodelfn
end

function on(::EpochBegin, ::Phase, cb::ToDevice, learner)
    model!(learner, cb.movemodelfn(learner.model))
end

stateaccess(::ToDevice) = (
    model = Write(),
    params = Write(),
    step = Write(),
    optimizer = Read(),
)
"""
    ToGPU()

Callback that moves model and batch data to the GPU during training.
Convenience for [`ToDevice`](#)`(Flux.gpu, Flux.gpu)`.
"""
ToGPU() = ToDevice(gpu, gpu)

function on(::StepBegin, ::Phase, cb::ToDevice, learner)
    learner.step.xs = cb.movedatafn(learner.step.xs)
    learner.step.ys = cb.movedatafn(learner.step.ys)
end
function garbagecollect()
    GC.gc()
    if Base.Sys.islinux()
        ccall(:malloc_trim, Cvoid, (Cint,), 0)
    end
end
"""
    GarbageCollect(nsteps)

Forces garbage collection every `nsteps` steps (100 by default).

Use this if you get memory leaks from, for example, parallel data loading.

On Linux systems, performs an additional C call (`malloc_trim`) that can
sometimes help.
"""
function GarbageCollect(nsteps::Int = 100)
    return throttle(
        CustomCallback((learner) -> garbagecollect(), StepEnd, Phase),
        StepEnd,
        freq = nsteps,
    )
end