Optimisers

Consider a simple linear regression. We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters W and b.

using Flux, Flux.Tracker

W = param(rand(2, 5))
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3

θ = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), θ)

We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:

using Flux.Tracker: grad, update!

η = 0.1 # Learning Rate
for p in (W, b)
  update!(p, -η * grads[p])
end

Running this will alter the parameters W and b, and our loss should go down. Flux provides a more general way to perform parameter updates like this, via optimisers.

opt = Descent(0.1) # Gradient descent with learning rate 0.1

for p in (W, b)
  update!(opt, p, grads[p])
end

An optimiser's update! accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass opt to our training loop, which will then update all of the model's parameters for us. Crucially, we can now easily replace Descent with a more advanced optimiser such as ADAM, as sketched below.
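
For instance, a minimal sketch that reuses the grads computed above and simply swaps in ADAM (the learning rate of 0.001 is just an illustrative choice):

opt = ADAM(0.001) # ADAM optimiser with learning rate 0.001

for p in (W, b)
  update!(opt, p, grads[p]) # same update! call, different update rule
end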

Optimiser Reference

All optimiser constructors return an object that, when passed to train!, will update the parameters passed to it.
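
As a sketch of that usage, assuming the four-argument form train!(loss, params, data, opt) and reusing the definitions above (data here is a hypothetical one-element dataset):

opt = ADAM(0.001)
data = [(rand(5), rand(2))] # hypothetical dataset: an iterable of (input, target) pairs
Flux.train!(loss, θ, data, opt) # calls loss on each datapoint and applies opt to θ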

Descent(η)

Classic gradient descent optimiser with learning rate η. For each parameter p and its gradient δp, this runs p -= η*δp.

Momentum(η = 0.01, ρ = 0.9)

Gradient descent with learning rate η and momentum ρ.
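
As a rough sketch of the rule (classical momentum; the exact implementation may differ), each parameter keeps a velocity v:

v = ρ*v - η*δp # exponentially decaying accumulation of scaled gradients
p += v         # step along the velocity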

Nesterov(η = 0.001, ρ = 0.9)

Gradient descent with learning rate η and Nesterov momentum ρ.

RMSProp(η = 0.001, ρ = 0.9)

RMSProp optimiser. Parameters other than learning rate don't need tuning. Often a good choice for recurrent networks.

ADAM(η = 0.001, β = (0.9, 0.999))

ADAM optimiser.
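
For example, with the defaults written out explicitly:

opt = ADAM(0.001, (0.9, 0.999)) # learning rate and the (β1, β2) decay rates as a tuple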

AdaMax(η = 0.001, β = (0.9, 0.999))

AdaMax optimiser. Variant of ADAM based on the ∞-norm.

ADAGrad(η = 0.1; ϵ = 1e-8)

ADAGrad optimiser. Parameters don't need tuning.

ADADelta(ρ = 0.9, ϵ = 1e-8)

ADADelta optimiser. Parameters don't need tuning.

AMSGrad(η = 0.001, β = (0.9, 0.999))

AMSGrad optimiser. Parameters don't need tuning.

NADAM(η = 0.001, β = (0.9, 0.999))

NADAM optimiser. Parameters don't need tuning.

ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0)

ADAMW optimiser. A variant of ADAM that fixes its weight decay regularisation.
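
For example, following the signature above, with an illustrative weight decay of 0.1:

opt = ADAMW(0.001, (0.9, 0.999), 0.1) # ADAM with decoupled weight decay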
