YVIq2
"""
    setcallbacks!(learner, callbacks)

Replace the whole callback collection of `learner`: `callbacks` is wrapped in a
fresh `Callbacks` object, which is stored in `learner.callbacks`.

Returns the newly created `Callbacks` object (the value of the assignment).
"""
function setcallbacks!(learner, callbacks)
    wrapped = Callbacks(callbacks)
    learner.callbacks = wrapped
end
"""
    addcallback!(learner, callback::AbstractCallback)

Append `callback` to the end of `learner`'s callbacks and rebuild
`learner.callbacks` from the extended list, then run `init!(callback, learner)`
so the new callback can set itself up against the learner.

`vcat` builds a new vector, so the previous callback vector is left unmodified.
Returns `nothing`.
"""
function addcallback!(learner, callback::AbstractCallback)
    extended = vcat(learner.callbacks.cbs, callback)
    learner.callbacks = Callbacks(extended)
    # Initialize only after the callback is registered on the learner.
    init!(callback, learner)
    return nothing
end
"""
    getcallback(learner, C::Type{<:FluxTraining.Callback})

Look up the first callback in `learner.callbacks.cbs` that is an instance of
`C`. Returns that callback, or `nothing` when no callback of type `C` is present.
"""
function getcallback(learner, C::Type{<:FluxTraining.Callback})
    registered = learner.callbacks.cbs
    position = findfirst(cb -> cb isa C, registered)
    position === nothing && return nothing
    return registered[position]
end
"""
    replacecallback!(learner, callback::C) where {C<:FluxTraining.Callback}

Replace the first callback of `learner` that is an instance of `C` (the type of
`callback`) with `callback`, and return the callback that was replaced.

If no callback of type `C` is registered, `callback` is added via
`addcallback!` and `nothing` is returned.

In both cases `init!(callback, learner)` is called on the new callback, so a
replacement is initialized the same way an added callback is.
"""
function replacecallback!(learner, callback::C) where {C<:FluxTraining.Callback}
    cbidx = findfirst(cb -> cb isa C, learner.callbacks.cbs)
    if isnothing(cbidx)
        # `addcallback!` registers and `init!`s the callback itself.
        FluxTraining.addcallback!(learner, callback)
        return nothing
    end
    # Work on a copy: the vector held by the current `Callbacks` object may be
    # shared with a caller (e.g. one passed to `setcallbacks!`), so it must not
    # be mutated in place.
    cbs = copy(learner.callbacks.cbs)
    oldcb = cbs[cbidx]
    cbs[cbidx] = callback
    FluxTraining.setcallbacks!(learner, cbs)
    # Mirror `addcallback!`: initialize the new callback after it is registered.
    init!(callback, learner)
    return oldcb
end
"""
    removecallback!(learner, C::Type{<:FluxTraining.Callback})

Remove the first callback of `learner` that is an instance of `C` and rebuild
`learner.callbacks` from the remaining callbacks.

Returns the removed callback, or `nothing` (leaving `learner` untouched) when no
callback of type `C` is registered.
"""
function removecallback!(learner, C::Type{<:FluxTraining.Callback})
    cbidx = findfirst(cb -> cb isa C, learner.callbacks.cbs)
    isnothing(cbidx) && return nothing
    # Copy before removing: the vector held by the current `Callbacks` object
    # may be shared with a caller, so it must not be mutated in place.
    cbs = copy(learner.callbacks.cbs)
    cb = popat!(cbs, cbidx)
    learner.callbacks = Callbacks(cbs)
    return cb
end