Optimisers
Consider a simple linear regression. We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters W and b.
using Flux, Flux.Tracker
W = param(rand(2, 5))
b = param(rand(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3
θ = Params([W, b]) # collect the parameters to differentiate
grads = Tracker.gradient(() -> loss(x, y), θ) # gradients of the loss with respect to θ
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
using Flux.Tracker: grad, update!
η = 0.1 # Learning Rate
for p in (W, b)
  update!(p, -η * grads[p])
end
Running this will alter the parameters W and b, and our loss should go down. Flux provides a more general way to do optimiser updates like this:
opt = Descent(0.1) # Gradient descent with learning rate 0.1
for p in (W, b)
  update!(opt, p, grads[p])
end
An optimiser's update! accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass opt to our training loop, which will update all parameters of the model in a loop. However, we can now easily replace Descent with a more advanced optimiser such as ADAM.
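For example, here is a minimal sketch (reusing the W, b and grads defined above) that swaps Descent for ADAM; only the constructor changes, and the update loop stays the same:
opt = ADAM(0.001) # ADAM with its default learning rate of 0.001
for p in (W, b)
  update!(opt, p, grads[p])
end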
Optimiser Reference
All optimisers return an object that, when passed to train!, will update the parameters passed to it.
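As a rough sketch of that usage, assuming the loss, W and b from the example above (the data collection of (x, y) pairs is a hypothetical stand-in for a real dataset):
using Flux: train!
data = [(rand(5), rand(2)) for i in 1:100] # hypothetical dummy (input, target) pairs
opt = ADAM(0.001)
train!(loss, Params([W, b]), data, opt) # one pass over data, updating W and b via opt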
Flux.Optimise.Descent — Type.
Descent(η)
Classic gradient descent optimiser with learning rate η. For each parameter p and its gradient δp, this runs p -= η*δp.
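A small illustrative sketch of that rule, using the update! interface shown above (the gradient δp here is made up):
opt = Descent(0.1)
p = param([1.0, 2.0])
δp = [0.5, 0.5] # a hypothetical gradient for p
update!(opt, p, δp) # applies p -= 0.1 * δp, leaving p at roughly [0.95, 1.95]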
Flux.Optimise.Momentum — Type.
Momentum(params, η = 0.01; ρ = 0.9)
Gradient descent with learning rate η and momentum ρ.
Flux.Optimise.Nesterov — Type.
Nesterov(eta, ρ = 0.9)
Gradient descent with learning rate η and Nesterov momentum ρ.
Flux.Optimise.RMSProp — Type.
RMSProp(η = 0.001, ρ = 0.9)
RMSProp optimiser. Parameters other than learning rate don't need tuning. Often a good choice for recurrent networks.
Flux.Optimise.ADAM — Type.
ADAM(η = 0.001, β = (0.9, 0.999))
ADAM optimiser.
Flux.Optimise.AdaMax — Type.
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
AdaMax optimiser. Variant of ADAM based on the ∞-norm.
Flux.Optimise.ADAGrad — Type.
ADAGrad(η = 0.1; ϵ = 1e-8)
ADAGrad optimiser. Parameters don't need tuning.
Flux.Optimise.ADADelta — Type.
ADADelta(ρ = 0.9, ϵ = 1e-8)
ADADelta optimiser. Parameters don't need tuning.
Flux.Optimise.AMSGrad — Type.
AMSGrad(η = 0.001, β = (0.9, 0.999))
AMSGrad optimiser. Parameters don't need tuning.
Flux.Optimise.NADAM — Type.
NADAM(η = 0.001, β = (0.9, 0.999))
NADAM optimiser. Parameters don't need tuning.
Flux.Optimise.ADAMW — Function.
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0)
ADAMW optimiser, a variant of ADAM that fixes weight decay regularization.