Automatic Differentiation in Flux

Flux's gradient function uses Zygote by default, and train! also calls this function to differentiate the model. Zygote has its own documentation, which in particular lists some important limitations.

Flux also has support for Enzyme.jl (documented below) and for Mooncake.jl.

Generic Gradient Interface

Flux.gradient (Method)
gradient(f, [adtype,] args...)

Returns a tuple containing ∂f/∂x for each argument x, the derivative (for scalar x) or the gradient. If no gradient is defined, ∂f/∂x will be nothing.

f(args...) must be a real number, see Zygote.jacobian for array output.

The optional argument adtype allows specifying the automatic differentiation backend.

We provide specific support and testing for the following backends: AutoZygote, AutoEnzyme, AutoMooncake, and AutoFiniteDifferences.

The package corresponding to any chosen backend (except Zygote) must be loaded in advance.

If no adtype is given, then Zygote.jl is used by default, unless at least one argument is of type Duplicated from Enzyme.jl, in which case Enzyme.jl is used.

See also withgradient to keep the value f(args...).

Examples

julia> Flux.gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)

julia> Flux.gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
([14.0, 22.0, 26.0],)

julia> Flux.gradient([7, 11], 0, 1) do x, y, d
         p = size(x, d)
         sum(x.^p .+ y)
       end
([14.0, 22.0], 2.0, nothing)

Specifying other AD backends:

julia> using Mooncake

julia> f(x) = sum(2 .* x)
f (generic function with 1 method)

julia> Flux.gradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
([2.0, 2.0, 2.0],)
Flux.withgradient (Method)
withgradient(f, [adtype,] args...)

Returns both the value of the function and the gradient, as a named tuple.

The optional argument adtype allows specifying the automatic differentiation backend among the supported ones: AutoZygote, AutoEnzyme, AutoMooncake, and AutoFiniteDifferences. The package corresponding to the chosen backend must be loaded in advance.

If no adtype is given, then Zygote.jl is used by default, unless at least one argument is of type Duplicated from Enzyme.jl, in which case Enzyme.jl is used.

See also gradient to get just the gradient.

Examples

julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))

julia> ∇ == gradient(/, 1, 2)
true

withgradient allows you to capture auxiliary outputs, in addition to the scalar used by gradient. To do this, f must return a Tuple or NamedTuple. Then it calculates grad = gradient(first∘f, args...) but returns the whole val = f(args...):

julia> withgradient([1,2,4]) do x
          z = 1 ./ x
          sum(z), z  # here z is an auxiliary output
       end
(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))

julia> withgradient(3.0, 4.0) do x, y
          (div = x/y, mul = x*y)
       end
(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))

Different AD backends can be specified:

julia> using Mooncake

julia> f(x) = sum(2 .* x)
f (generic function with 1 method)

julia> Flux.withgradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
(val = 12.0, grad = ([2.0, 2.0, 2.0],))

Automatic Differentiation using Zygote.jl

The default AD backend in Flux is Zygote. Besides gradient calculation, Zygote also supports higher-order derivatives, Jacobians, Hessians, and pullbacks.

Zygote.jacobian (Method)
jacobian(f, args...) -> Tuple

For each array a ∈ args this returns a matrix with Ja[k,i] = ∂y[k]/∂a[i] where y = f(args...) is usually a vector. Arrays of higher dimension are treated like vec(a), or vec(y) for output.

For scalar x::Number ∈ args, the result is a vector Jx[k] = ∂y[k]/∂x, while for scalar y all results have just one row.

With any other argument type, no result is produced, even if gradient would work.

This reverse-mode Jacobian needs to evaluate the pullback once for each element of y. Doing so is usually only efficient when length(y) is small compared to length(a), otherwise forward mode is likely to be better.

See also withjacobian, hessian, hessian_reverse.

Examples

julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1]  # first index (rows) is output
3×7 Matrix{Int64}:
 200    0    0  0  0  0  0
   0  400    0  0  0  0  0
   0    0  600  0  0  0  0

julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1)  # scalar argument has vector jacobian
([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])

julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])
Warning

For arguments of any type except Number & AbstractArray, the result is nothing.

julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str")
([3 0 0; 0 12 0; 0 0 27], nothing)

julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)

julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))  # gradient understands the tuple
([4 4 4], (6, 1))
Zygote.withjacobian (Method)
withjacobian(f, args...)

Returns both the value f(args...) and the jacobian as a named tuple.

julia> withjacobian(cumsum, [1,2,3])
(val = [1, 3, 6], grad = ([1 0 0; 1 1 0; 1 1 1],))
Zygote.hessian (Function)
hessian(f, x)

Construct the Hessian ∂²f/∂x², where x is a real number or an array, and f(x) is a real number. When x is an array, the result is a matrix H[i,j] = ∂²f/∂x[i]∂x[j], using linear indexing x[i] even if the argument is higher-dimensional.

This uses forward over reverse, ForwardDiff over Zygote, calling hessian_dual(f, x). See hessian_reverse for an all-Zygote alternative.

See also diaghessian to compute only the diagonal part.

Examples

julia> hessian(x -> x[1]*x[2], randn(2))
2×2 Matrix{Float64}:
 0.0  1.0
 1.0  0.0

julia> hessian(x -> sum(x.^3), [1 2; 3 4])  # uses linear indexing of x
4×4 Matrix{Int64}:
 6   0   0   0
 0  18   0   0
 0   0  12   0
 0   0   0  24

julia> hessian(sin, pi/2)
-1.0
Zygote.hessian_reverse (Function)
hessian_reverse(f, x)

This should be equivalent to hessian(f, x), but implemented using reverse over reverse mode, all Zygote. (This is usually much slower, and more likely to find errors.)
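
A minimal sketch of a sanity check, assuming Zygote is loaded; for sum(x.^3) the Hessian is Diagonal(6 .* x), so the two implementations should agree:

using Zygote

H = Zygote.hessian_reverse(x -> sum(x.^3), [1.0, 2.0])   # expected [6.0 0.0; 0.0 12.0]
H == Zygote.hessian(x -> sum(x.^3), [1.0, 2.0])          # expected true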

Zygote.diaghessian (Function)
diaghessian(f, args...) -> Tuple

Diagonal part of the Hessian. Returns a tuple containing, for each argument x, h of the same shape with h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]. The original evaluation y = f(args...) must give a real number y.

For one vector argument x, this is equivalent to (diag(hessian(f,x)),). Like hessian it uses ForwardDiff over Zygote.

Warning

For arguments of any type except Number & AbstractArray, the result is nothing.

Examples

julia> diaghessian(x -> sum(x.^3), [1 2; 3 4])[1]
2×2 Matrix{Int64}:
  6  12
 18  24

julia> Diagonal(vec(ans)) == hessian(x -> sum(x.^3), [1 2; 3 4])  # full Hessian is diagonal
true

julia> diaghessian((x,y) -> sum(x .* y .* y'), [1 22; 333 4], [0.5, 0.666])  # two array arguments
([0.0 0.0; 0.0 0.0], [2.0, 8.0])

julia> diaghessian(atan, 1, 2)  # two scalar arguments
(-0.16, 0.16)

julia> hessian(xy -> atan(xy[1], xy[2]), [1, 2])  # full Hessian is not diagonal
2×2 Matrix{Float64}:
 -0.16  -0.12
 -0.12   0.16
ZygoteRules.pullback (Function)
pullback(f, args...)
pullback(f, ::Params)

Returns the value of the function f and a back-propagator function, which can be called to obtain a tuple containing ∂f/∂x for each argument x, the derivative (for scalar x) or gradient.

y, back = pullback(f, args...)
∇ = back(seed)

back must be called with a start value seed matching the output of f(args...). If f(args...) returns a number, seed should be a number. If f(args...) returns an array, seed should be an equally-sized array.

See also withgradient to obtain the value and gradients in one call, and gradient for obtaining just the gradients.

julia> y, back = pullback(*, 2.0, 3.0, 5.0);

julia> y
30.0

julia> back(1.0)
(15.0, 10.0, 6.0)

julia> back(2.0)
(30.0, 20.0, 12.0)

julia> y, back = pullback(x -> [x, x], 1.0);

julia> y
2-element Vector{Float64}:
 1.0
 1.0

julia> back([1.0, 1.0])
(2.0,)

julia> back([2.0, nothing])
(2.0,)

ChainRules for Zygote

Zygote uses ChainRules.jl to define how to differentiate functions.

Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using the following methods:

ChainRulesCore.ignore_derivatives (Function)
ignore_derivatives(f::Function)

Tells the AD system to ignore the gradients of the wrapped closure. The primal computation (forward pass) is executed normally.

ignore_derivatives() do
    value = rand()
    push!(collection, value)
end

Using this incorrectly could lead to incorrect gradients. For example, the following function will have zero gradients with respect to its argument:

function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)
    end
    return sum(y)
end
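
A quick way to see this, as a sketch assuming Zygote is loaded and wrong_grads is defined as above:

using Zygote

# x only enters the result through the ignored push!, so its reported gradient vanishes:
Zygote.gradient(wrong_grads, 2.0)   # expected (nothing,), although the true derivative is 1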
ignore_derivatives(x)

Tells the AD system to ignore the gradients of the argument. Can be used to avoid unnecessary computation of gradients.

ignore_derivatives(x) * w
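
As a sketch (the weight w is hypothetical; assumes Zygote and ChainRulesCore are loaded), the ignored argument receives no gradient while the other argument does:

using Zygote, ChainRulesCore

Zygote.gradient((x, w) -> ignore_derivatives(x) * w, 2.0, 3.0)   # expected (nothing, 2.0)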
ChainRulesCore.@non_differentiable (Macro)
@non_differentiable(signature_expression)

A helper to make it easier to declare that a method is not differentiable. This is a short-hand for defining an frule and rrule that return NoTangent() for all partials (even for the function s̄elf-partial itself)

Keyword arguments should not be included.

julia> @non_differentiable Base.:(==)(a, b)

julia> _, pullback = rrule(==, 2.0, 3.0);

julia> pullback(1.0)
(NoTangent(), NoTangent(), NoTangent())

You can place type-constraints in the signature:

julia> @non_differentiable Base.length(xs::Union{Number, Array})

julia> frule((ZeroTangent(), 1), length, [2.0, 3.0])
(2, NoTangent())
Warning

This helper macro covers only the simple common cases. It does not support where-clauses. For these you can declare the rrule and frule directly.


To manually supply the gradient for one function, you should define a method of rrule. ChainRules has detailed documentation on how this works.
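
For example, a minimal sketch of such a rule (the function twice is hypothetical; assumes ChainRulesCore is loaded):

using ChainRulesCore

twice(x) = 2x   # hypothetical function to be given a hand-written rule

function ChainRulesCore.rrule(::typeof(twice), x)
    y = twice(x)
    twice_pullback(ȳ) = (NoTangent(), 2 * ȳ)   # no tangent for the function itself, then ∂y/∂x * ȳ
    return y, twice_pullback
end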

ChainRulesCore.rrule (Function)
rrule([::RuleConfig,] f, x...)

Expressing x as the tuple (x₁, x₂, ...) and the output tuple of f(x...) as Ω, return the tuple:

(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))

Where the second return value is the propagation rule or pullback. It takes in cotangents (Ω̄₁, Ω̄₂, ...) corresponding to the outputs, and returns s̄elf, the cotangent for the internal fields of the function itself (for closures), followed by the cotangents (x̄₁, x̄₂, ...) for the inputs.

If no method matching rrule(f, xs...) has been defined, then return nothing.

Examples:

unary input, unary output scalar function:

julia> x = rand();

julia> sinx, sin_pullback = rrule(sin, x);

julia> sinx == sin(x)
true

julia> sin_pullback(1) == (NoTangent(), cos(x))
true

binary input, unary output scalar function:

julia> x, y = rand(2);

julia> hypotxy, hypot_pullback = rrule(hypot, x, y);

julia> hypotxy == hypot(x, y)
true

julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
true

The optional RuleConfig option allows specifying rrules only for AD systems that support given features. If not needed, then it can be omitted and the rrule without it will be hit as a fallback. This is the case for most rules.

See also: frule, @scalar_rule, RuleConfig

ChainRulesCore.frule (Function)
frule([::RuleConfig,] (Δf, Δx...), f, x...)

Expressing the output of f(x...) as Ω, return the tuple:

(Ω, ΔΩ)

The second return value is the tangent w.r.t. the output.

If no method matching frule((Δf, Δx...), f, x...) has been defined, then return nothing.

Examples:

unary input, unary output scalar function:

julia> dself = NoTangent();

julia> x = rand()
0.8236475079774124

julia> sinx, Δsinx = frule((dself, 1), sin, x)
(0.7336293678134624, 0.6795498147167869)

julia> sinx == sin(x)
true

julia> Δsinx == cos(x)
true

Unary input, binary output scalar function:

julia> sincosx, Δsincosx = frule((dself, 1), sincos, x);

julia> sincosx == sincos(x)
true

julia> Δsincosx[1] == cos(x)
true

julia> Δsincosx[2] == -sin(x)
true

Note that technically speaking Julia does not have multiple-output functions, just functions that return a single output that is iterable, like a Tuple. So this is actually a Tangent:

julia> Δsincosx
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)

The optional RuleConfig option allows specifying frules only for AD systems that support given features. If not needed, then it can be omitted and the frule without it will be hit as a fallback. This is the case for most rules.

See also: rrule, @scalar_rule, RuleConfig

ChainRulesCore.@scalar_rule (Macro)
@scalar_rule(f(x₁, x₂, ...),
             @setup(statement₁, statement₂, ...),
             (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
             (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
             ...)

A convenience macro that generates simple scalar forward or reverse rules using the provided partial derivatives. Specifically, generates the corresponding methods for frule and rrule:

function ChainRulesCore.frule((NoTangent(), Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...)
    Ω = f(x₁, x₂, ...)
    $(statement₁, statement₂, ...)
    return Ω, (
            (∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
            (∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
            ...
        )
end

function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
    Ω = f(x₁, x₂, ...)
    $(statement₁, statement₂, ...)
    return Ω, ((ΔΩ₁, ΔΩ₂, ...)) -> (
            NoTangent(),
            ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...,
            ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...,
            ...
        )
end

If no type constraints in f(x₁, x₂, ...) within the call to @scalar_rule are provided, each parameter in the resulting frule/rrule definition is given a type constraint of Number. Constraints may also be explicitly provided to override the Number constraint, e.g. f(x₁::Complex, x₂), which will constrain x₁ to Complex and x₂ to Number.

At present this does not support defining for closures/functors. Thus in reverse-mode, the first returned partial, representing the derivative with respect to the function itself, is always NoTangent(). And in forward-mode, the first input to the returned propagator is always ignored.

The result of f(x₁, x₂, ...) is automatically bound to Ω. This allows the primal result to be conveniently referenced (as Ω) within the derivative/setup expressions.

This macro assumes complex functions are holomorphic. In general, for non-holomorphic functions, the frule and rrule must be defined manually.

If the derivative is one (e.g. for identity functions), true can be used as the most general multiplicative identity.

The @setup argument can be elided if no setup code is needed. In other words:

@scalar_rule(f(x₁, x₂, ...),
             (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
             (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
             ...)

is equivalent to:

@scalar_rule(f(x₁, x₂, ...),
             @setup(nothing),
             (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
             (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
             ...)
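
As a concrete sketch (the function cube is hypothetical; assumes ChainRulesCore is loaded):

using ChainRulesCore

cube(x) = x^3                   # hypothetical scalar function
@scalar_rule(cube(x), 3 * x^2)  # one input and one output, so a single partial ∂f/∂x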

For examples, see ChainRules' rulesets directory.

See also: frule, rrule.

ChainRulesCore.NoTangent (Type)
NoTangent() <: AbstractZero

This tangent indicates that the derivative does not exist. It is the tangent type for primal types that are not differentiable, such as integers or booleans (when they are not being used to represent floating-point values). The only valid way to perturb such values is to not change them at all. As a consequence, NoTangent is functionally identical to ZeroTangent(), but it provides additional semantic information.

Adding NoTangent() to a primal is generally wrong: gradient-based methods cannot be used to optimize over discrete variables. An optimization package making use of this might want to check for such a case.

Note

This does not indicate that the derivative is not implemented, but rather that mathematically it is not defined.

This mostly shows up as the derivative with respect to dimension, index, or size arguments.

function rrule(::typeof(fill), x, len::Int)
    y = fill(x, len)
    fill_pullback(ȳ) = (NoTangent(), @thunk(sum(ȳ)), NoTangent())
    return y, fill_pullback
end
ChainRulesCore.ZeroTangent (Type)
ZeroTangent() <: AbstractZero

The additive identity for tangents. This is basically the same as 0. A derivative of ZeroTangent() does not propagate through the primal function.

ChainRulesCore.RuleConfig (Type)
RuleConfig{T}

The configuration for what rules to use. T: traits. This should be a Union of all special traits needed for rules to be allowed to be defined for your AD. If nothing special this should be set to Union{}.

AD authors should define a subtype of RuleConfig to use when calling frule/rrule.

Rule authors can dispatch on this config when defining rules. For example:

# only define rrule for `pop!` on AD systems where mutation is supported.
rrule(::RuleConfig{>:SupportsMutation}, ::typeof(pop!), ::Vector) = ...

# this definition of map is for any AD that defines a forwards mode
rrule(conf::RuleConfig{>:HasForwardsMode}, ::typeof(map), ::Vector) = ...

# this definition of map is for any AD that only defines a reverse mode.
# It is not as good as the rrule that can be used if the AD defines a forward-mode as well.
rrule(conf::RuleConfig{>:Union{NoForwardsMode, HasReverseMode}}, ::typeof(map), ::Vector) = ...

For more details see rule configurations and calling back into AD.

ChainRulesCore.Tangent (Type)
Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent

This type represents the tangent for a struct/NamedTuple, or Tuple. P is the corresponding primal type that this is a tangent for.

Tangent{P} should have fields (technically properties) that match a subset of the fields of the primal type, and each should be a tangent type matching the primal type of that field. Fields of P that are not present in the Tangent are treated as zero.

T is an implementation detail representing the backing data structure. For Tuple it will be a Tuple, and for everything else it will be a NamedTuple. It should not be passed in by user.

For Tangents of Tuples, iterate and getindex are overloaded to behave as they do for a tuple. For Tangents of structs, getproperty is overloaded to allow for accessing values via tangent.fieldname. Any fields not explicitly present in the Tangent are treated as being set to ZeroTangent(). To make a Tangent have all the fields of the primal, the canonicalize function is provided.
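
A small sketch of building and reading a Tangent for a hypothetical struct (assumes ChainRulesCore is loaded):

using ChainRulesCore

struct Point   # hypothetical primal type
    x::Float64
    y::Float64
end

t = Tangent{Point}(x = 1.0)   # tangent carrying only the x field
t.x                           # 1.0; fields not given, like y, are treated as ZeroTangent()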

ChainRulesCore.canonicalize (Function)
canonicalize(tangent::Tangent{P}) -> Tangent{P}

Return the canonical Tangent for the primal type P. The property names of the returned Tangent match the field names of the primal, and all fields of P not present in the input tangent are explicitly set to ZeroTangent().
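
Continuing the hypothetical Point sketch above:

canonicalize(Tangent{Point}(x = 1.0))   # expected Tangent{Point}(x = 1.0, y = ZeroTangent())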


Gradient customization for other AD packages such as Enzyme and Mooncake has to be done according to their own documentation.

Automatic Differentiation using Enzyme.jl

Enzyme.jl is a new package for automatic differentiation. Like Zygote.jl, calling gradient(f, x) causes it to hook into the compiler and transform the code that is executed while calculating f(x), in order to produce code for ∂f/∂x. But it does so much later in the optimisation process (on LLVM instead of Julia's untyped IR), which you can read about here. It needs far fewer custom rules than Zygote/ChainRules, and in particular is able to support mutation of arrays.

Flux now builds in support for this, using Enzyme's own Duplicated type. Calling Duplicated on any Flux model which was defined using @layer will allocate space for the gradient, and passing that to gradient (or withgradient, or train!) will then use Enzyme instead of Zygote. The gradient functions still return the gradient as usual, which can then be passed to update!:

julia> using Flux, Enzyme

julia> model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);  # from model zoo

julia> dup_model = Enzyme.Duplicated(model)  # this allocates space for the gradient
Duplicated(
  Chain(
    Dense(784 => 32, σ),                # 25_120 parameters
    Dense(32 => 10),                    # 330 parameters
    NNlib.softmax,
  ),
  # norm(∇) ≈ 0.0f0
)                   # Total: 4 arrays, 25_450 parameters, 199.391 KiB.

julia> x1 = randn32(28*28, 1);  # fake image

julia> y1 = [i==3 for i in 0:9];  # fake label

julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))  # uses Enzyme
((layers = ((weight = Float32[-0.010354728 0.032972857 …
    -0.0014538406], σ = nothing), nothing),), nothing, nothing)

The gradient returned here is also stored within dup_model. Both share the same arrays – what is returned is not a copy, just a view of the same memory (wrapped in NamedTuples instead of structs). They will all be set to zero when you call gradient again, then replaced with the new values. Alternatively, gradient(f, args...; zero=false) will add the new gradient to what's already stored.

Writing Const(x1) is optional; just plain x1 is implicitly constant. Any set of Duplicated and Const arguments may appear in any order, so long as there is at least one Duplicated.

The gradient grads_f[1] can be passed to update! as usual. But for convenience, you may also use what is stored within Duplicated. These are equivalent ways to perform an update step:

julia> opt_state = Flux.setup(Adam(), model)

julia> ans == Flux.setup(Adam(), dup_model)

julia> Flux.update!(opt_state, model, grads_f[1])  # exactly as for Zygote gradients

julia> Flux.update!(opt_state, dup_model)  # equivalent new path, Enzyme only

Instead of using these Flux functions, you can also use Enzyme's own functions directly. Enzyme.gradient works like this:

julia> grads_e = Enzyme.gradient(Reverse, (m,x,y) -> sum(abs2, m(x) .- y), model, Const(x1), Const(y1))
(Chain(Dense(784 => 32, σ), Dense(32 => 10), softmax), nothing, nothing)

julia> grads_f[1].layers[2].bias ≈ grads_e[1].layers[2].bias
true

Note that what Enzyme.gradient returns is an object like deepcopy(model), of the same type: grads_e[1] isa Chain. But its fields contain the same gradient.

Flux.gradient (Method)
gradient(f, args::Union{Any,EnzymeCore.Duplicated}...)

This should return the same answer as gradient(f, args...), but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.

Only available when Enzyme is loaded!

This method is used when at least one argument is of type Duplicated. All non-duplicated arguments are treated as Const. Note that Enzyme's Active is not supported.

Besides being returned, the gradient is also stored within the Duplicated object. Calling Enzyme.Duplicated(model) allocates space for the gradient, which is zeroed before use when calling gradient. With the keyword zero=false, the new gradient will instead be added to what is already stored.

Examples

julia> using Flux

julia> model = Chain(Dense([3.0;;]));

julia> Flux.gradient(model, [1]) do m, x  # computed using Zygote
         sum(abs2, m(x))
       end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), [18.0])

julia> using Enzyme

julia> dup_model = Duplicated(model);  # allocates space for gradient

julia> Flux.gradient(dup_model, Const([1])) do m, x  # Enzyme, returns the same
         sum(abs2, m(x))
       end
((layers = ((weight = [6.0;;], bias = [6.0], σ = nothing),),), nothing)

julia> dup_model  # same gradient is also stored within Duplicated
Duplicated(
  Chain(
    Dense(1 => 1),                      # 2 parameters
  ),
  # norm(∇) ≈ 8.49
)

julia> Flux.destructure((weight = [6.0;;], bias = [6.0]))[1] |> norm
8.48528137423857

julia> Flux.gradient(dup_model, [1]; zero=false) do m, x  # implicit Const([1]), and grad accumulation
         sum(abs2, m(x))
       end
((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing)
Flux.withgradient (Method)
withgradient(f, args::Union{Any,EnzymeCore.Duplicated}...)

This should return the same answer as withgradient(f, model, args...), but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.

Only available when Enzyme is loaded!

This method is used when at least one argument is of type Duplicated. All non-duplicated arguments will be differentiated as well; mark them as Const to avoid this. Note that Enzyme's Active is not supported.

Examples

julia> using Flux, Enzyme

julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only);

julia> model(3)
14.52

julia> Flux.withgradient(m -> m(3), model)  # this uses Zygote
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))

julia> Flux.withgradient(m -> m(3), Duplicated(model))  # this uses Enzyme
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))

Enzyme.jl has its own extensive documentation.

Second-order AD

If you calculate a gradient within the loss function, then training will involve 2nd derivatives. While this is in principle supported by Zygote.jl, there are many bugs, and Enzyme.jl is probably a better choice.
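
As an illustration, here is a sketch (hypothetical model, data, and loss) of a loss whose definition itself calls gradient, so that training on it requires second derivatives:

using Flux

model = Dense(3 => 1)        # hypothetical model
x = randn(Float32, 3, 8)     # hypothetical data

# The input-gradient penalty below makes this a second-order problem:
function penalised_loss(m, x)
    gx, = Flux.gradient(x -> sum(abs2, m(x)), x)
    return sum(abs2, m(x)) + sum(abs2, gx)
end

# Differentiating this w.r.t. the model is gradient-of-gradient; with the default
# Zygote backend it may hit the bugs mentioned above, and Enzyme may fare better.
grads = Flux.gradient(penalised_loss, model, x)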