YVIq2
"""
    epoch!(learner, phase[, dataiter])

Run a single epoch of `phase`: every batch yielded by `dataiter` is passed to
`step!(learner, phase, batch)`, with the whole pass wrapped by `runepoch` so the
epoch-level callback events fire. When `dataiter` is omitted it is looked up in
`learner.data` under the key returned by `phasedataiter(phase)`.

Returns whatever `runepoch` returns.
"""
function epoch!(learner, phase::Phase, dataiter = learner.data[phasedataiter(phase)])
    # The epoch handler passed in by `runepoch` is not needed here; `step!`
    # sets up its own per-step handler through `runstep`.
    runbatches(_) = foreach(batch -> step!(learner, phase, batch), dataiter)
    return runepoch(runbatches, learner, phase)
end
# Generic function for processing one batch within a phase.
# Methods in this file specialise it on the phase type (`TrainingPhase`,
# `ValidationPhase`); `epoch!` calls it once for every batch of the epoch.
function step! end
# Run one optimisation step on `batch` during a training phase.
#
# `batch` is destructured into inputs and targets (assumes a 2-element
# iterable such as `(xs, ys)` — TODO confirm for every data iterator in use).
# The step state is the `PropDict` built by `runstep`; this method fills in
# `ŷs`, `loss` and `grads` next to the initial `xs`/`ys`.
# Callback events fired, in order: `LossBegin` and `BackwardBegin` (both from
# inside the differentiated closure), then `BackwardEnd`.
# Returns the value of `runstep`, i.e. the step state.
function step!(learner, phase::TrainingPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (; xs = xs, ys = ys)) do handle, state
        # This closure is the function being differentiated. `_gradient` calls
        # it with `learner.model` (directly for Optimisers.jl, via a
        # zero-argument wrapper for legacy Flux `Params`). Callbacks invoked
        # here therefore run during the gradient computation.
        state.grads = _gradient(learner.optimizer, learner.model, learner.params) do model
            state.ŷs = model(state.xs)
            handle(LossBegin())
            state.loss = learner.lossfn(state.ŷs, state.ys)
            handle(BackwardBegin())
            return state.loss
        end
        handle(BackwardEnd())
        # Apply the optimiser. `_update!` returns the (possibly new)
        # params/optimiser state and model, so both learner fields are rebound.
        learner.params, learner.model = _update!(
            learner.optimizer, learner.params, learner.model, state.grads)
    end
end
# Support both legacy Flux.jl optimisers (implicit `Params`) and Optimisers.jl optimisers (explicit state)
# `_gradient(f, optimizer, model, params)` differentiates the loss closure `f`
# using whichever gradient API matches the optimiser in use.

# Fallback (Optimisers.jl style): differentiate `f` explicitly with respect to
# the model and return the gradient for that single argument.
_gradient(f, _, m, _) = first(gradient(f, m))

# Legacy Flux optimiser with implicit `Params`: differentiate a zero-argument
# closure that evaluates `f(m)` with respect to `ps`.
function _gradient(f, ::Flux.Optimise.AbstractOptimiser, m, ps::Params)
    return gradient(ps) do
        f(m)
    end
end
# `_update!(optimizer, params, model, grads)` applies one optimiser step and
# returns the pair `(params_or_state, model)` that the caller should store.

# Legacy Flux optimiser: `update!` is called for its side effect (its return
# value is discarded), and the inputs are handed back unchanged.
function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
    update!(optimizer, params, grads)
    return (params, model)
end

# Optimisers.jl: `st` is presumably the optimiser state tree (the caller passes
# `learner.params`); forward the new state and model from `Optimisers.update!`.
function _update!(_, st, model, grads)
    newstate, newmodel = Optimisers.update!(st, model, grads)
    return (newstate, newmodel)
end
"""
    step!(learner, ::ValidationPhase, batch)

Evaluate `learner.model` on one `batch` and record the predictions (`ŷs`) and
the loss (`loss`) in the step state. No gradients are computed and no callback
events are fired from inside the step body.

Returns the step state produced by `runstep`.
"""
function step!(learner, phase::ValidationPhase, batch)
    inputs, targets = batch
    return runstep(learner, phase, (; xs = inputs, ys = targets)) do _, state
        state.ŷs = learner.model(state.xs)
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end
# Run `epochfn(handlefn)` framed by `EpochBegin` and `EpochEnd` events.
#
# `handlefn(event)` forwards `event` to the learner's callback runner for
# `phase`; the runner is re-read from `learner.callbacks` on every call.
# A `CancelEpochException` raised anywhere in the framed section is logged at
# debug level and `EpochEnd` is still dispatched. Any other exception is
# rethrown. Returns the result of the last `EpochEnd` dispatch.
function runepoch(epochfn, learner, phase::Phase)
    handlefn = e -> handle(learner.callbacks.runner, e, phase, learner)
    return try
        handlefn(EpochBegin())
        epochfn(handlefn)
        handlefn(EpochEnd())
    catch err
        # Only epoch cancellation is recoverable; everything else propagates.
        err isa CancelEpochException || rethrow()
        @debug "Epoch skipped" error = err
        handlefn(EpochEnd())
    end
end
# Run `stepfn(handlefn, state)` framed by `StepBegin` and `StepEnd` events.
#
# `state` is a `PropDict` seeded from the key/value pairs of `initialstate`
# and is assigned to `learner.step` before any event fires. A
# `CancelStepException` is logged at debug level and swallowed (no `StepEnd`
# is dispatched in that case). Any other exception is rethrown.
# Returns `state` in both the normal and the cancelled case.
function runstep(stepfn, learner, phase::Phase, initialstate = (;))
    state = PropDict(pairs(initialstate))
    handlefn = e -> handle(learner.callbacks.runner, e, phase, learner)
    try
        learner.step = state
        handlefn(StepBegin())
        stepfn(handlefn, state)
        handlefn(StepEnd())
    catch err
        err isa CancelStepException || rethrow()
        @debug "Step skipped" error = err
    end
    return state
end
# Utilities
"""
    fit!(learner, nepochs, (trainiter, validiter))

Train `learner` for `nepochs` epochs. Each epoch runs a `TrainingPhase` over
`trainiter`, followed by a `ValidationPhase` over `validiter`.

Returns `nothing`.
"""
function fit!(learner, nepochs::Int, (trainiter, validiter))
    for _ in 1:nepochs
        epoch!(learner, TrainingPhase(), trainiter)
        epoch!(learner, ValidationPhase(), validiter)
    end
    return nothing
end
"""
    fit!(learner, nepochs)

Train `learner` for `nepochs` epochs on the iterators stored in
`learner.data.training` and `learner.data.validation`.
"""
fit!(learner, nepochs::Int) =
    fit!(learner, nepochs, (learner.data.training, learner.data.validation))