module Train

using LinearAlgebra
using Optimisers: Optimisers
using Functors: fmap, fmapstructure
using ..Flux: Flux # used only in docstring
import ..Flux.Optimise: train!, update!  # during 0.13, we add methods to the old functions

export setup, train!

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote, Params

"""
opt_state = setup(rule, model)
This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
It differs from `Optimisers.setup` in that it:
* has one extra check for mutability (since Flux expects to mutate the model in-place,
while Optimisers.jl is designed to return an updated model)
* has methods which accept Flux's old optimisers, and convert them.
(The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.)
!!! compat "New"
This function was added in Flux 0.13.9. It was not used by the old "implicit"
interface, using `Flux.Optimise` module and [`Flux.params`](@ref).
# Example
```jldoctest
julia> model = Dense(2=>1, leakyrelu; init=ones);
julia> opt_state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [0.0]), σ = ())
julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps:
julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y
sum(abs.(m(x) .- y)) * 100
end
julia> model.bias # was zero, mutated by Flux.train!
1-element Vector{Float64}:
10.19
julia> opt_state # mutated by Flux.train!
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [-10.09]), σ = ())
```
"""
function setup(rule::Optimisers.AbstractRule, model)
  state = Optimisers.setup(rule, model)
  # This check only needs foreach; using fmap caused https://github.com/FluxML/Flux.jl/issues/2144
  fmapstructure(model, exclude = Optimisers.isnumeric) do x
    Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`.
                                       If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
  end
  state
end

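# A minimal sketch of what the mutability check above guards against (an assumption for
# illustration, not a doctest): `Optimisers.maywrite` is true for ordinary `Array`s, so a
# standard layer passes, while an immutable parameter array (here a StaticArrays.SMatrix)
# is expected to hit the `error` branch:
#
#   using Flux, Optimisers, StaticArrays
#   Flux.setup(Optimisers.Descent(0.1), Dense(2 => 1))            # ok, Arrays are mutable
#   Flux.setup(Optimisers.Descent(0.1), (; weight = SA[1.0 2.0])) # expected to error: not fully mutable
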
"""
train!(loss, model, data, opt_state)
Uses a `loss` function and training `data` to improve the `model`'s parameters
according to a particular optimisation rule encoded in `opt_state`.
Iterates through `data` once, evaluating for each `d in data` either
`loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.
For example, with these definitions...
```
data = [(x1, y1), (x2, y2), (x3, y3)]
loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
opt_state = Flux.setup(Adam(), model) # explicit setup of optimiser momenta
```
...calling `Flux.train!(loss3, model, data, opt_state)` runs a loop much like this:
```
for d in data
∂L∂m = gradient(loss3, model, d...)[1]
update!(opt_state, model, ∂L∂m)
end
```
You can also write this loop yourself, if you need more flexibility.
For this reason `train!` is not highly extensible.
It adds only a few features to the loop above:
* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
* Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl).
!!! compat "New"
This method was added in Flux 0.13.9.
It has significant changes from the one used by Flux ≤ 0.13:
* It now takes the `model` itself, not the result of [`Flux.params`](@ref).
(This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
* Instead of `loss` being a function which accepts only the data,
now it must also accept the `model` itself, as the first argument.
* `opt_state` should be the result of [`Flux.setup`](@ref). Using an optimiser
such as `Adam()` without this step should give you a warning.
* Callback functions are not supported.
(But any code can be included in the above `for` loop.)
"""
function train!(loss, model, data, opt; cb = nothing)
  isnothing(cb) || error("""train! does not support callback functions.
                            For more control use a loop with `gradient` and `update!`.""")
  @withprogress for (i, d) in enumerate(data)
    d_splat = d isa Tuple ? d : (d,)
    l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
    if !isfinite(l)
      throw(DomainError("Loss is $l on data item $i, stopping training"))
    end
    opt, model = Optimisers.update!(opt, model, gs[1])
    @logprogress Base.haslength(data) ? i/length(data) : nothing
  end
end

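# A minimal usage sketch of the non-tuple branch of `d_splat` above (an assumption for
# illustration, not a doctest): each `d` is a NamedTuple, so the loss is called as
# `loss(model, d)` with no splatting.
#
#   model = Dense(2 => 1)
#   data  = [(; x = rand(Float32, 2), y = rand(Float32, 1)) for _ in 1:3]
#   loss(m, d) = Flux.mse(m(d.x), d.y)
#   opt_state = Flux.setup(Adam(), model)
#   Flux.train!(loss, model, data, opt_state)
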
# This method lets you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
  train!(loss, model, data, _rule_to_state(model, rule); cb)
end

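# A minimal usage sketch of the method above (an assumption for illustration, not a
# doctest): a stateless rule such as `Optimisers.Descent` can be passed directly, while a
# stateful rule such as `Adam` triggers the warning from `_rule_to_state` below, because
# its momenta would be discarded when `train!` returns.
#
#   model = Dense(2 => 1)
#   data  = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:3]
#   loss(m, x, y) = Flux.mse(m(x), y)
#   Flux.train!(loss, model, data, Optimisers.Descent(0.1))  # fine: Descent has no state
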
function _rule_to_state(model, rule::Optimisers.AbstractRule)
  state = setup(rule, model)
  @gensym warn_id
  name = typeof(rule).name.name
  fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf
    leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes.
                                       Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id
    leaf
  end
  state
end

end # module Train