How Flux Works: Gradients and Layers
Taking Gradients
Flux's core feature is taking gradients of Julia code. The gradient function takes another Julia function f and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
julia> using Flux
julia> f(x) = 3x^2 + 2x + 1;
julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2
julia> df(2)
14.0
julia> d2f(x) = gradient(df, x)[1]; # d²f/dx² = 6
julia> d2f(2)
6.0
When a function has many parameters, we can get gradients of each one at the same time:
julia> f(x, y) = sum((x .- y).^2);
julia> gradient(f, [2, 1], [2, 0])
([0.0, 2.0], [-0.0, -2.0])
These gradients are based on x and y. Flux works by instead taking gradients based on the weights and biases that make up the parameters of a model.
Machine learning models can often have hundreds of parameters, so Flux lets you work with collections of parameters via the params function. You can get the gradient of all parameters used in a program without explicitly passing them in.
julia> x = [2, 1];
julia> y = [2, 0];
julia> gs = gradient(Flux.params(x, y)) do
         f(x, y)
       end
Grads(...)
julia> gs[x]
2-element Vector{Float64}:
0.0
2.0
julia> gs[y]
2-element Vector{Float64}:
-0.0
-2.0
Here, gradient takes a zero-argument function; no arguments are necessary because the params tell it what to differentiate.
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
Building Simple Models
Consider a simple linear regression, which tries to predict an output array y from an input x.
W = rand(2, 5)
b = rand(2)
predict(x) = W*x .+ b
function loss(x, y)
  ŷ = predict(x)
  sum((y .- ŷ).^2)
end
x, y = rand(5), rand(2) # Dummy data
loss(x, y) # ~ 3
To improve the prediction we can take the gradients of the loss with respect to W and b and perform gradient descent.
using Flux
gs = gradient(() -> loss(x, y), Flux.params(W, b))
Now that we have gradients, we can pull them out and update W to train the model.
W̄ = gs[W]
W .-= 0.1 .* W̄
loss(x, y) # ~ 2.5
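The bias can be updated in exactly the same way. As a small sketch (b̄ is just a name chosen here for the gradient of b, mirroring W̄ above):

b̄ = gs[b]
b .-= 0.1 .* b̄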
The loss has decreased a little, meaning that our prediction for x is closer to the target y. If we have some data we can already try training the model.
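For example, a minimal hand-written training loop is just the gradient-and-update step above, repeated; the 100 steps and the 0.1 learning rate below are arbitrary choices for illustration:

for step in 1:100
  grads = gradient(() -> loss(x, y), Flux.params(W, b))
  W .-= 0.1 .* grads[W]
  b .-= 0.1 .* grads[b]
end

loss(x, y) # noticeably smaller than before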
All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can look very different: they might have millions of parameters or complex control flow. Let's see how Flux handles more complex models.
Building Layers
It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like sigmoid (σ) in between them. In the above style we could write this as:
using Flux
W1 = rand(3, 5)
b1 = rand(3)
layer1(x) = W1 * x .+ b1
W2 = rand(2, 3)
b2 = rand(2)
layer2(x) = W2 * x .+ b2
model(x) = layer2(σ.(layer1(x)))
model(rand(5)) # => 2-element vector
This works but is fairly unwieldy, with a lot of repetition, especially as we add more layers. One way to factor this out is to create a function that returns linear layers.
function linear(in, out)
  W = randn(out, in)
  b = randn(out)
  x -> W * x .+ b
end
linear1 = linear(5, 3) # we can access linear1.W etc
linear2 = linear(3, 2)
model(x) = linear2(σ.(linear1(x)))
model(rand(5)) # => 2-element vector
Another (equivalent) way is to create a struct that explicitly represents the affine layer.
struct Affine
  W
  b
end
Affine(in::Integer, out::Integer) =
Affine(randn(out, in), randn(out))
# Overload call, so the object can be used as a function
(m::Affine)(x) = m.W * x .+ m.b
a = Affine(10, 5)
a(rand(10)) # => 5-element vector
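To connect this back to the two-layer example above, here is a sketch that swaps the hand-rolled layers for Affine structs:

layer1 = Affine(5, 3)
layer2 = Affine(3, 2)

model(x) = layer2(σ.(layer1(x)))

model(rand(5)) # => 2-element vector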
Congratulations! You just built the Dense layer that comes with Flux. Flux has many interesting layers available, but they're all things you could have built yourself very easily.
(There is one small difference with Dense: for convenience it also takes an activation function, like Dense(10 => 5, σ).)
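For example (a quick sketch; the weights are random, so only the output's shape and range are predictable):

d = Dense(10 => 5, σ)

d(rand(Float32, 10)) # => 5-element vector, each entry in (0, 1) because of σ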
Stacking It Up
It's pretty common to write models that look something like:
layer1 = Dense(10 => 5, σ)
# ...
model(x) = layer3(layer2(layer1(x)))
For long chains, it might be a bit more intuitive to have a list of layers, like this:
using Flux
layers = [Dense(10 => 5, σ), Dense(5 => 2), softmax]
model(x) = foldl((x, m) -> m(x), layers, init = x)
model(rand(10)) # => 2-element vector
Handily, this is also provided for in Flux:
model2 = Chain(
  Dense(10 => 5, σ),
  Dense(5 => 2),
  softmax)
model2(rand(10)) # => 2-element vector
This quickly starts to look like a high-level deep learning library; yet you can see how it falls out of simple abstractions, and we lose none of the power of Julia code.
A nice property of this approach is that because "models" are just functions (possibly with trainable parameters), you can also see this as simple function composition.
m = Dense(5 => 2) ∘ Dense(10 => 5, σ)
m(rand(10))
Likewise, Chain will happily work with any Julia function.
m = Chain(x -> x^2, x -> x+1)
m(5) # => 26
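Layers and plain Julia functions can also be mixed freely inside the same Chain. A small sketch:

m = Chain(Dense(10 => 5, σ), x -> x .* 2, Dense(5 => 2))

m(rand(Float32, 10)) # => 2-element vector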
Layer Helpers
There is still one problem with this Affine layer: Flux does not know to look inside it. This means that Flux.train! won't see its parameters, nor will gpu be able to move them to your GPU. These features are enabled by the @functor macro:
Flux.@functor Affine
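Afterwards Flux can see inside the layer. As a quick check (a sketch, using the Affine(in, out) constructor defined above):

a = Affine(3, 2)

Flux.params(a) # => Params holding a.W and a.b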
Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the Affine layer as follows:
function Affine((in, out)::Pair; bias=true, init=Flux.randn32)
  W = init(out, in)
  b = Flux.create_bias(W, bias, out)
  Affine(W, b)
end
Affine(3 => 1, bias=false, init=ones) |> gpu
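As a quick sanity check (a sketch, assuming Flux.create_bias stores false when the bias is disabled, so that broadcasting .+ false is a no-op):

a = Affine(3 => 1, bias=false, init=ones)

a.b          # => false, i.e. no trainable bias
a([1, 2, 3]) # => 1-element vector [6.0], since W is all ones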
Functors.@functor (Macro)

@functor T
@functor T (x,)
Adds methods to functor allowing recursion into objects of type T, and reconstruction. Assumes that T has a constructor accepting all of its fields, which is true unless you have provided an inner constructor which does not.
By default all fields of T are considered children; this can be restricted by providing a tuple of field names.
Examples
julia> struct Foo; x; y; end
julia> @functor Foo
julia> Functors.children(Foo(1,2))
(x = 1, y = 2)
julia> _, re = Functors.functor(Foo(1,2));
julia> re((10, 20))
Foo(10, 20)
julia> struct TwoThirds a; b; c; end
julia> @functor TwoThirds (a, c)
julia> ch2, re3 = Functors.functor(TwoThirds(10,20,30));
julia> ch2
(a = 10, c = 30)
julia> re3(("ten", "thirty"))
TwoThirds("ten", 20, "thirty")
julia> fmap(x -> 10x, TwoThirds(Foo(1,2), Foo(3,4), 56))
TwoThirds(Foo(10, 20), Foo(3, 4), 560)