Automatic Differentiation using Enzyme.jl

Enzyme.jl is a new package for automatic differentiation. Like Zygote.jl, calling gradient(f, x) causes it to hook into the compiler and transform the code that is executed while calculating f(x), in order to produce code for ∂f/∂x. But it does so much later in the optimisation process (on LLVM rather than on Julia's untyped IR), which you can read more about in Enzyme's own documentation. It needs far fewer custom rules than Zygote/ChainRules, and in particular it is able to support mutation of arrays.

Flux now has built-in support for this, using Enzyme's own Duplicated type. Calling Duplicated on any Flux model defined using @layer allocates space for the gradient, and passing the wrapped model to gradient (or withgradient, or train!) will then use Enzyme instead of Zygote. The gradient functions still return the gradient as usual, which can then be passed to update!:

julia> using Flux, Enzyme

julia> model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);  # from model zoo

julia> dup_model = Enzyme.Duplicated(model)  # this allocates space for the gradient
Duplicated(
  Chain(
    Dense(784 => 32, σ),                # 25_120 parameters
    Dense(32 => 10),                    # 330 parameters
    NNlib.softmax,
  ),
  # norm(∇) ≈ 0.0f0
)                   # Total: 4 arrays, 25_450 parameters, 199.391 KiB.

julia> x1 = randn32(28*28, 1);  # fake image

julia> y1 = [i==3 for i in 0:9];  # fake label

julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))  # uses Enzyme
((layers = ((weight = Float32[-0.010354728 0.032972857 …
    -0.0014538406], σ = nothing), nothing),), nothing, nothing)

The gradient returned here is also stored within dup_model. Both share the same arrays – what is returned is not a copy, just a view of the same memory (wrapped in NamedTuples instead of structs). They will all be set to zero when you call gradient again, then replaced with the new values. Alternatively, gradient(f, args...; zero=false) will add the new gradient to what's already stored.
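
As a quick check (a sketch, assuming the example above has just been run, and relying on Enzyme storing the shadow model in the dval field of Duplicated), the arrays really are shared:

julia> dup_model.dval.layers[1].weight === grads_f[1].layers[1].weight  # same array, not a copy
true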

Writing Const(x1) is optional; a plain x1 is implicitly treated as constant. Any mixture of Duplicated and Const arguments may appear in any order, so long as there is at least one Duplicated.
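
For instance, this call (continuing the example above) computes the same gradient without writing Const; the trailing semicolon just hides the printed result:

julia> Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, x1, y1);  # x1, y1 implicitly Const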

The gradient grads_f[1] can be passed to update! as usual. But for convenience, you may also use what is stored within Duplicated. These are equivalent ways to perform an update step:

julia> opt_state = Flux.setup(Adam(), model)

julia> ans == Flux.setup(Adam(), dup_model)  # setup also accepts the Duplicated wrapper, giving the same state

julia> Flux.update!(opt_state, model, grads_f[1])  # exactly as for Zygote gradients

julia> Flux.update!(opt_state, dup_model)  # equivalent new path, Enzyme only

Instead of using these Flux functions, you can also use Enzyme's own functions directly. Enzyme.gradient works like this:

julia> grads_e = Enzyme.gradient(Reverse, (m,x,y) -> sum(abs2, m(x) .- y), model, Const(x1), Const(y1))
(Chain(Dense(784 => 32, σ), Dense(32 => 10), softmax), nothing, nothing)

julia> grads_f[1].layers[2].bias ≈ grads_e[1].layers[2].bias
true

Note that what Enzyme.gradient returns is an object of the same type as the model, like deepcopy(model), so that grads_e[1] isa Chain. But its fields contain the same gradient values.
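
This difference in structure is easy to check (continuing the example above):

julia> grads_e[1] isa Chain, grads_f[1] isa NamedTuple  # same values, different containers
(true, true)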

There is also a method of train! which similarly takes Duplicated(model):

julia> opt_state = Flux.setup(Adam(0), model);

julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)

Second-order AD

If you calculate a gradient within the loss function, then training will involve 2nd derivatives. While this is in principle supported by Zygote.jl, there are many bugs, and Enzyme.jl is probably a better choice.
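
As a toy illustration (a sketch using plain scalars rather than a Flux model; whether nested calls like this work may depend on your Enzyme version), forward-over-reverse differentiation of a first derivative gives a second derivative:

julia> f(x) = x^3;

julia> df(x) = Enzyme.gradient(Reverse, f, x)[1];  # first derivative, 3x^2

julia> Enzyme.gradient(Forward, df, 2.0)[1]  # second derivative, 6x, evaluated at x = 2
12.0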

Listing

Flux.gradient (Method)
gradient(f, args::Union{Const,Duplicated}...)

This should return the same answer as gradient(f, args...), but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.

Only available when Enzyme is loaded!

This method is used when at least one argument is of type Duplicated; all other arguments are wrapped in Const. Note that Enzyme's Active is not supported.

Besides being returned, the gradient is also stored within the Duplicated object. Calling Enzyme.Duplicated(model) allocates space for the gradient, which is zeroed before use when calling gradient. With the keyword zero=false, the new gradient will instead be added to what is already stored.

Experimental

Enzyme support like this is new and somewhat experimental. This method was added in Flux 0.15.

Example

julia> using Flux, LinearAlgebra  # LinearAlgebra provides norm, used below

julia> model = Chain(Dense([3.0;;]));

julia> Flux.gradient(model, [1]) do m, x  # computed using Zygote
         sum(abs2, m(x))
       end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0])

julia> using Enzyme

julia> dup_model = Duplicated(model);  # allocates space for gradient

julia> Flux.gradient(dup_model, Const([1])) do m, x  # Enzyme, returns the same
         sum(abs2, m(x))
       end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)

julia> dup_model  # same gradient is also stored within Duplicated
Duplicated(
  Chain(
    Dense(1 => 1),                      # 2 parameters
  ),
  # norm(∇) ≈ 8.49
)

julia> Flux.destructure((weight = [6.0;;], bias = [6.0]))[1] |> norm  # matches the norm(∇) shown above
8.48528137423857

julia> Flux.gradient(dup_model, [1]; zero=false) do m, x  # implicit Const([1]), and gradient accumulation
         sum(abs2, m(x))
       end
((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing)
Flux.withgradient (Method)
withgradient(f, args::Union{Const,Duplicated}...)

This should return the same answer as withgradient(f, args...), but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.

Only available when Enzyme is loaded!

Experimental

Enzyme support like this is new and somewhat experimental. This method was added in Flux 0.15.

Example

julia> using Flux, Enzyme

julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only);

julia> model(3)
14.52

julia> Flux.withgradient(m -> m(3), model)  # this uses Zygote
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))

julia> Flux.withgradient(m -> m(3), Duplicated(model))  # this uses Enzyme
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))

The function f may return a Tuple or NamedTuple, with the loss as the first element. The gradient is then grad = gradient(first∘f, args...) but the returned value is val = f(args...):

julia> Flux.withgradient(m -> (m(3), "aux"), Duplicated(model))
(val = (14.52, "aux"), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))

julia> Flux.withgradient(m -> (loss=m(3), aux=round.(m.(1:3); digits=3)), Duplicated(model))
(val = (loss = 14.52, aux = [4.84, 9.68, 14.52]), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
Flux.Train.train! (Method)
train!(loss, Duplicated(model), data, opt_state)

This method uses Enzyme.jl instead of Zygote.jl to compute the gradients, but is otherwise the same as train!(loss, model, data, opt_state).

Only available when Enzyme is loaded.

New

This method was added in Flux 0.13.9.
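
Example

A minimal sketch, re-using a model, fake data x1, y1, and an opt_state like those constructed earlier on this page:

julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), Duplicated(model), [(x1, y1)], opt_state)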


Enzyme.jl has its own extensive documentation.