Automatic Differentiation using Enzyme.jl
Enzyme.jl is a new package for automatic differentiation. Like Zygote.jl, calling gradient(f, x) causes it to hook into the compiler and transform the code that is executed while calculating f(x), in order to produce code for ∂f/∂x. But it does so much later in the optimisation process (on LLVM instead of Julia's untyped IR), which you can read more about in Enzyme's own documentation. It needs far fewer custom rules than Zygote/ChainRules, and in particular it is able to support mutation of arrays.
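For instance, Enzyme can differentiate code which writes into arrays, which Zygote's rules cannot handle. A minimal sketch, with an invented function sum_of_squares (the output format assumes a recent Enzyme version, which returns one gradient per argument):
julia> using Enzyme
julia> function sum_of_squares(x)  # fills a buffer in-place; Zygote would raise a mutation error
           buf = similar(x)
           for i in eachindex(x)
               buf[i] = x[i]^2
           end
           sum(buf)
       end;
julia> Enzyme.gradient(Reverse, sum_of_squares, [1.0, 2.0, 3.0])  # d/dx sum(x.^2) = 2x
([2.0, 4.0, 6.0],)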
Flux now builds in support for this, using Enzyme's own Duplicated type. Calling Duplicated on any Flux model which was defined using @layer will allocate space for the gradient, and passing that to gradient (or withgradient, or train!) will then use Enzyme instead of Zygote. The gradient functions still return the gradient as usual, which can then be passed to update!:
julia> using Flux, Enzyme
julia> model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax); # from model zoo
julia> dup_model = Enzyme.Duplicated(model) # this allocates space for the gradient
Duplicated(
  Chain(
    Dense(784 => 32, σ),                # 25_120 parameters
    Dense(32 => 10),                    # 330 parameters
    NNlib.softmax,
  ),
  # norm(∇) ≈ 0.0f0
)                   # Total: 4 arrays, 25_450 parameters, 199.391 KiB.
julia> x1 = randn32(28*28, 1); # fake image
julia> y1 = [i==3 for i in 0:9]; # fake label
julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1)) # uses Enzyme
((layers = ((weight = Float32[-0.010354728 0.032972857 …
-0.0014538406], σ = nothing), nothing),), nothing, nothing)
The gradient returned here is also stored within dup_model. Both share the same arrays – what is returned is not a copy, just a view of the same memory (wrapped in NamedTuples instead of structs). They will all be set to zero when you call gradient again, then replaced with the new values. Alternatively, gradient(f, args...; zero=false) will add the new gradient to what's already stored.
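To see this sharing directly: Enzyme's Duplicated struct keeps the model in its val field and the gradient buffer in its dval field, and (as a quick check, re-using grads_f from above) the arrays returned by gradient are the very same objects:
julia> grads_f[1].layers[1].weight === dup_model.dval.layers[1].weight  # identical arrays, not copies
true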
Writing Const(x1) is optional; a plain x1 is implicitly constant. Any mix of Duplicated and Const arguments may appear in any order, so long as there is at least one Duplicated.
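For instance, the model need not come first, and constant arguments get nothing in the corresponding slot of the result. A small check, using a second wrapper here so as not to overwrite the gradient stored in dup_model above:
julia> dup_model2 = Enzyme.Duplicated(model);  # a separate gradient buffer for the same model
julia> g = Flux.gradient((x, m, y) -> sum(abs2, m(x) .- y), x1, dup_model2, y1);  # plain x1, y1 are implicitly constant
julia> (g[1], g[3])  # no gradient is returned for constant arguments
(nothing, nothing)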
The gradient grads_f[1] can be passed to update! as usual. But for convenience, you may also use what is stored within Duplicated. These are equivalent ways to perform an update step:
julia> opt_state = Flux.setup(Adam(), model)
julia> ans == Flux.setup(Adam(), dup_model)
julia> Flux.update!(opt_state, model, grads_f[1]) # exactly as for Zygote gradients
julia> Flux.update!(opt_state, dup_model) # equivalent new path, Enzyme only
Instead of using these Flux functions, you can also use Enzyme's own functions directly. Enzyme.gradient works like this:
julia> grads_e = Enzyme.gradient(Reverse, (m,x,y) -> sum(abs2, m(x) .- y), model, Const(x1), Const(y1))
(Chain(Dense(784 => 32, σ), Dense(32 => 10), softmax), nothing, nothing)
julia> grads_f[1].layers[2].bias ≈ grads_e[1].layers[2].bias
true
Note that what Enzyme.gradient returns is an object like deepcopy(model), of the same type (grads_e[1] isa Chain), but its fields contain the same gradient.
There is also a method of train! which similarly takes Duplicated(model):
julia> opt_state = Flux.setup(Adam(0), model);
julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
Second-order AD
If you calculate a gradient within the loss function, then training will involve 2nd derivatives. While this is in principle supported by Zygote.jl, there are many bugs, and Enzyme.jl is probably a better choice.
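As a minimal sketch of the idea on a plain scalar function (not a Flux model), nesting Enzyme in forward-over-reverse fashion looks like this; whether such nesting works for a full model depends on your Enzyme version and on what the loss does:
julia> using Enzyme
julia> f(x) = x^3;
julia> df(x) = Enzyme.gradient(Reverse, f, x)[1];  # first derivative, 3x^2
julia> Enzyme.gradient(Forward, df, 1.0)[1]  # second derivative, 6x, evaluated at x = 1
6.0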
Listing
Flux.gradient — Method
gradient(f, args::Union{Const,Duplicated}...)
This should return the same answer as gradient(f, args...), but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
Only available when Enzyme is loaded!
This method is used when at least one argument is of type Duplicated, and all unspecified arguments are wrapped in Const. Note that Enzyme's Active is not supported.
Besides being returned, the gradient is also stored within the Duplicated object. Calling Enzyme.Duplicated(model) allocates space for the gradient, which is zeroed before use when calling gradient. With the keyword zero=false, the new gradient will instead be added to what is already stored.
Enzyme support like this is new and somewhat experimental. This method was added in Flux 0.15.
Example
julia> using Flux
julia> model = Chain(Dense([3.0;;]));
julia> Flux.gradient(model, [1]) do m, x  # computed using Zygote
         sum(abs2, m(x))
       end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0])
julia> using Enzyme
julia> dup_model = Duplicated(model); # allocates space for gradient
julia> Flux.gradient(dup_model, Const([1])) do m, x  # Enzyme, returns the same
         sum(abs2, m(x))
       end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)
julia> dup_model # same gradient is also stored within Duplicated
Duplicated(
  Chain(
    Dense(1 => 1),                      # 2 parameters
  ),
  # norm(∇) ≈ 8.49
)
julia> using LinearAlgebra  # for norm
julia> Flux.destructure((weight = [6.0;;], bias = [6.0]))[1] |> norm
8.48528137423857
julia> Flux.gradient(dup_model, [1]; zero=false) do m, x  # implicit Const([1]), and grad accumulation
         sum(abs2, m(x))
       end
((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing)
Flux.withgradient — Method
withgradient(f, args::Union{Const,Duplicated}...)
This should return the same answer as withgradient(f, model, args...), but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
Only available when Enzyme is loaded!
Enzyme support like this is new and somewhat experimental. This method was added in Flux 0.15.
Example
julia> using Flux, Enzyme
julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only);
julia> model(3)
14.52
julia> Flux.withgradient(m -> m(3), model) # this uses Zygote
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
julia> Flux.withgradient(m -> m(3), Duplicated(model)) # this uses Enzyme
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
The function f may return a Tuple or NamedTuple, with the loss as the first element. The gradient is then grad = gradient(first∘f, args...) but the returned value is val = f(args...):
julia> Flux.withgradient(m -> (m(3), "aux"), Duplicated(model))
(val = (14.52, "aux"), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
julia> Flux.withgradient(m -> (loss=m(3), aux=round.(m.(1:3); digits=3)), Duplicated(model))
(val = (loss = 14.52, aux = [4.84, 9.68, 14.52]), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
Flux.Train.train! — Method
train!(loss, Duplicated(model), data, opt_state)
This method uses Enzyme.jl instead of Zygote.jl to compute the gradients, but is otherwise the same as train!(loss, model, data, opt_state).
Only available when Enzyme is loaded.
This method was added in Flux 0.13.9.
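Example
A minimal sketch of a call, with an invented model, loss, and one batch of fake data:
julia> using Flux, Enzyme
julia> model = Chain(Dense(2 => 1));
julia> dup_model = Duplicated(model);  # allocates space for the gradient
julia> data = [(randn32(2, 8), randn32(1, 8))];  # one batch of fake inputs and targets
julia> opt_state = Flux.setup(Adam(), model);
julia> Flux.train!((m, x, y) -> Flux.mse(m(x), y), dup_model, data, opt_state)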
Enzyme.jl has its own extensive documentation.