Automatic Differentiation using Zygote.jl

Flux re-exports the gradient function from Zygote and uses it within train! to differentiate the model. Zygote has its own documentation, which in particular lists some important limitations.

Implicit style

Flux primarily uses what Zygote calls "implicit" gradients, as described in Zygote's documentation.

Zygote.gradient (Function)
gradient(() -> loss(), ps::Params) -> Grads

Gradient with implicit parameters. Takes a zero-argument function and returns a dictionary-like container whose keys are the arrays x in ps.

julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];

julia> g = gradient(Params([x, y])) do
         sum(x .* y .* z')
       end
Grads(...)

julia> g[x]
2×3 Matrix{Float64}:
 7.0  70.0  700.0
 8.0  80.0  800.0

julia> haskey(g, z)  # only x and y are parameters
false

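As a further sketch of the implicit style (the arrays W, b and x are illustrative, and Zygote is assumed to be installed), the loss below is a closure over global arrays. Only the arrays wrapped in Params get gradients, and each gradient has the same shape as its array.

```julia
using Zygote

W = [1.0 2.0 3.0; 4.0 5.0 6.0]   # illustrative weights
b = [0.0, 0.0]                   # illustrative bias
x = [1.0, 10.0, 100.0]           # input; deliberately NOT a parameter

ps = Params([W, b])
g = gradient(() -> sum(W * x .+ b), ps)

g[W]   # d/dW sum(W*x .+ b) = ones(2) * x', so each row equals x'
g[b]   # a vector of ones, same shape as b
```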
Zygote.Params (Type)
Params([A, B])

Container for implicit parameters, used when differentiating a zero-argument function () -> loss(A, B) with respect to A, B.

Zygote.Grads (Type)

Dictionary-like container returned when taking gradients with respect to implicit parameters. For an array W, appearing within Params([W, A, B...]), the gradient is g[W].
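A common use of Grads is a manual gradient-descent step: iterate over the Params and update each array in place using g[p]. This is a minimal sketch; the arrays, the loss and the step size 0.01 are all illustrative assumptions, not what train! does internally.

```julia
using Zygote

W = [1.0 2.0; 3.0 4.0]      # illustrative parameters
b = [0.5, -0.5]
x = [1.0, 1.0]              # fixed input
loss() = sum(abs2, W * x .+ b)

ps = Params([W, b])
before = loss()

g = gradient(loss, ps)      # a Zygote.Grads, keyed by the arrays in ps
for p in ps
    # In-place update, so W and b (and hence ps) keep referring to the same arrays
    p .-= 0.01 .* g[p]
end

after = loss()              # smaller than `before` for this small step
```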

Zygote.jacobian (Method)
jacobian(loss, ::Params)

Like gradient with implicit parameters, this method takes a zero-argument function and returns an IdDict-like object, now containing the Jacobian for each parameter.


julia> xs = [1 2; 3 4]; ys = [5,7,9];

julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))
Grads(...)

julia> Jxy[ys]
2×3 Matrix{Int64}:
 1  0  0
 0  1  0

julia> Jxy[xs]
2×4 Matrix{Int64}:
 2  6  4  8
 2  6  4  8
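For a linear map the implicit Jacobian can be checked by hand. In this sketch (A and v are illustrative), J[v] should equal A itself, and J[A] has one column per entry of vec(A), matching the column-major ordering used in the example above.

```julia
using Zygote

A = [1.0 2.0; 3.0 4.0]
v = [1.0, -1.0]

J = jacobian(() -> A * v, Params([A, v]))

J[v]   # ∂(A*v)[k]/∂v[i] = A[k,i], so this equals A
J[A]   # 2×4: ∂(A*v)[k]/∂A[j,i] = (k == j) * v[i], columns ordered as vec(A)
```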

Explicit style

The other way of using Zygote, and of using most other AD packages, is to explicitly provide a function and its arguments.

Zygote.gradient (Method)
gradient(f, args...)

Returns a tuple containing ∂f/∂x for each argument x: the derivative (for scalar x) or the gradient.

f(args...) must be a real number; see jacobian for array output.

See also withgradient to keep the value f(args...), and pullback for value and back-propagator.

julia> gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)

julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
([14.0, 22.0, 26.0],)

julia> gradient([7, 11], 0, 1) do x, y, d
         p = size(x, d)
         sum(x.^p .+ y)
       end
([14.0, 22.0], 2.0, nothing)

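With several array arguments, the explicit style returns one gradient per argument, in order. A small sketch with an illustrative function f(W, x) = sum(W * x), whose gradients can be written down by hand as ones(2) * x' and W' * ones(2):

```julia
using Zygote

f(W, x) = sum(W * x)

W = [1.0 2.0; 3.0 4.0]
x = [10.0, 20.0]

dW, dx = gradient(f, W, x)
# dW equals ones(2) * x', i.e. each row is x'
# dx equals W' * ones(2), i.e. the column sums of W
```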
Zygote.withgradient (Method)
withgradient(f, args...)
withgradient(f, ::Params)

Returns both the value of the function and the gradient, as a named tuple.

julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))

julia> ∇ == gradient(/, 1, 2)
true

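For a scalar-valued function, withgradient gives the same information as calling pullback and then seeding the back-propagator with 1. The sketch below compares the two on an illustrative f(a, b) = a / b; treating them as equivalent here is an assumption for this scalar case, not a statement about how withgradient is implemented.

```julia
using Zygote

f(a, b) = a / b

y, back = pullback(f, 1.0, 2.0)   # value and back-propagator
grads = back(1.0)                 # seed the scalar output with cotangent 1

res = withgradient(f, 1.0, 2.0)   # named tuple (val = ..., grad = ...)
# Expected: y == res.val == 0.5 and grads == res.grad == (0.5, -0.25)
```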
Zygote.jacobian (Method)
jacobian(f, args...) -> Tuple

For each array a ∈ args, this returns a matrix with Ja[k,i] = ∂y[k]/∂a[i], where y = f(args...) is usually a vector. Arrays of higher dimension are treated like vec(a), or vec(y) for output.

For scalar x::Number ∈ args, the result is a vector Jx[k] = ∂y[k]/∂x, while for scalar y all results have just one row.

With any other argument type, no result is produced, even if gradient would work.

This reverse-mode Jacobian needs to evaluate the pullback once for each element of y. Doing so is usually only efficient when length(y) is small compared to length(a), otherwise forward mode is likely to be better.

See also withjacobian, hessian, hessian_reverse.
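To see why this reverse-mode Jacobian costs one pullback per output element, the sketch below builds it row by row: seeding the pullback with a one-hot cotangent for output k returns row k. The function f is hypothetical, and this loop is an illustration of the idea rather than a copy of Zygote's own implementation.

```julia
using Zygote

# Hypothetical function with 3 outputs and 2 inputs
f(a) = [a[1] * a[2], a[2]^2, sum(a)]

a = [2.0, 3.0]
y, back = pullback(f, a)

J = zeros(length(y), length(a))
for k in eachindex(y)
    seed = zeros(length(y))
    seed[k] = 1.0                  # one-hot cotangent for output k
    J[k, :] = back(seed)[1]        # one pullback call gives one row
end

J == jacobian(f, a)[1]             # same matrix as the built-in jacobian
```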


julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1]  # first index (rows) is output
3×7 Matrix{Int64}:
 200    0    0  0  0  0  0
   0  400    0  0  0  0  0
   0    0  600  0  0  0  0

julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1)  # scalar argument has vector jacobian
([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])

julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])

For arguments of any type except Number & AbstractArray, the result is nothing.

julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str")
([3 0 0; 0 12 0; 0 0 27], nothing)

julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)

julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))  # gradient understands the tuple
([4, 4, 4], (6, 1))
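The same contrast holds for other structured arguments. In this sketch the NamedTuple p and the function model are hypothetical: gradient returns a NamedTuple with one entry per field, while jacobian returns nothing for that argument because it is neither a Number nor an AbstractArray.

```julia
using Zygote

# Hypothetical structured argument: a NamedTuple of parameters
p = (w = 2.0, c = 3.0)
model(p, x) = p.w * x + p.c

g = gradient(p -> model(p, 5.0), p)[1]     # NamedTuple: (w = 5.0, c = 1.0)
J = jacobian(p -> [model(p, 5.0)], p)[1]   # nothing: structured args are skipped
```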

Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using ChainRules:

ChainRulesCore.ignore_derivatives (Function)
ignore_derivatives(f::Function)

Tells the AD system to ignore the gradients of the wrapped closure. The primal computation (forward pass) is executed normally.

ignore_derivatives() do
    value = rand()
    push!(collection, value)
end
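A typical use is a side effect such as logging inside a loss. In this sketch the global history and the function logged_loss are hypothetical; the push! runs during the forward pass, but AD does not try to differentiate it, so the gradient is just that of sum(abs2, x).

```julia
using Zygote, ChainRulesCore

history = Float64[]   # hypothetical log of loss values

function logged_loss(x)
    y = sum(abs2, x)
    ignore_derivatives() do
        push!(history, y)   # side effect, hidden from AD
    end
    return y
end

g = gradient(logged_loss, [1.0, 2.0])[1]   # gradient of sum(abs2, x) is 2x
# history now holds the single forward-pass value 5.0
```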

Using this incorrectly could lead to incorrect gradients. For example, the following function will have zero gradients with respect to its argument:

function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)
    end
    return sum(y)
end
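The sketch below contrasts wrong_grads with a hypothetical non-mutating version, right_grads, that computes the same value. Both return 5.0 for x = 2.0, but only the second lets the dependence on x reach the AD system. The zero gradient may be reported by Zygote as nothing, so the check treats nothing as zero.

```julia
using Zygote, ChainRulesCore

function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)          # x enters y out of the AD system's sight
    end
    return sum(y)
end

# Hypothetical fix: same value, but x stays visible to AD
right_grads(x) = sum(ones(3)) + x

wrong = gradient(wrong_grads, 2.0)[1]   # no gradient flows to x
right = gradient(right_grads, 2.0)[1]   # 1.0
```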

ignore_derivatives(x)

Tells the AD system to ignore the gradients of the argument. This can be used to avoid unnecessary computation of gradients.

ignore_derivatives(x) * w
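Used on a value, ignore_derivatives acts as a "stop-gradient": the value is used as-is, but no gradient flows back through it. In this sketch (f is hypothetical), x appears twice, and only the term outside ignore_derivatives contributes to ∂f/∂x, while the gradient with respect to w still uses the value of x.

```julia
using Zygote, ChainRulesCore

# Hypothetical function: x is treated as a constant in the first term only
f(x, w) = ignore_derivatives(x) * w + x

dx, dw = gradient(f, 3.0, 2.0)
# dx: only the "+ x" term contributes, so 1.0
# dw: ∂(x * w)/∂w uses the value x = 3.0, so 3.0
```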

ChainRulesCore.@non_differentiable (Macro)

A helper to make it easier to declare that a method is not differentiable. This is a shorthand for defining an frule and rrule that return NoTangent() for all partials (even for the function's own partial, s̄elf).

Keyword arguments should not be included.

julia> @non_differentiable Base.:(==)(a, b)

julia> _, pullback = rrule(==, 2.0, 3.0);

julia> pullback(1.0)
(NoTangent(), NoTangent(), NoTangent())

You can place type-constraints in the signature:

julia> @non_differentiable Base.length(xs::Union{Number, Array})

julia> frule((ZeroTangent(), 1), length, [2.0, 3.0])
(2, NoTangent())

This helper macro covers only the simple common cases. It does not support where-clauses. For these, you can declare the rrule and frule directly.
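A practical case is a helper that mutates a buffer, which Zygote generally cannot differentiate through. In this sketch the functions normalizer and scaled_sum are hypothetical: declaring normalizer non-differentiable makes AD treat its result as a constant, so the gradient of sum(x) / normalizer(x) is just 1 / normalizer(x) for every element.

```julia
using Zygote, ChainRulesCore

# Hypothetical helper that mutates a buffer
function normalizer(x)
    buf = similar(x)
    buf .= abs.(x)          # in-place mutation
    return maximum(buf)
end

# AD now treats normalizer's output as a constant
@non_differentiable normalizer(x)

scaled_sum(x) = sum(x) / normalizer(x)

g = gradient(scaled_sum, [1.0, -4.0])[1]   # each entry is 1 / 4
```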

To manually supply the gradient for one function, you should define a method of rrule. ChainRules has detailed documentation on how this works.

ChainRulesCore.rrule (Function)
rrule([::RuleConfig,] f, x...)

Expressing x as the tuple (x₁, x₂, ...) and the output tuple of f(x...) as Ω, return the tuple:

(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))

Here the second return value is the propagation rule, or pullback. It takes in the cotangents (Ω̄₁, Ω̄₂, ...) corresponding to the outputs, and returns s̄elf, the cotangent for the internal values of the function itself (for closures), followed by the cotangents (x̄₁, x̄₂, ...) of the inputs.

If no method matching rrule(f, xs...) has been defined, then return nothing.


Unary input, unary output scalar function:

julia> x = rand();

julia> sinx, sin_pullback = rrule(sin, x);

julia> sinx == sin(x)
true

julia> sin_pullback(1) == (NoTangent(), cos(x))
true

Binary input, unary output scalar function:

julia> x, y = rand(2);

julia> hypotxy, hypot_pullback = rrule(hypot, x, y);

julia> hypotxy == hypot(x, y)
true

julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
true
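Defining a method of rrule for your own function follows the same pattern. In this sketch the function mysquare is hypothetical: the pullback returns NoTangent() for the function itself (it has no internal state), followed by the cotangent for x. Zygote is assumed to pick up the rule when differentiating mysquare.

```julia
using Zygote, ChainRulesCore

# Hypothetical function with a hand-written rule
mysquare(x) = x^2

function ChainRulesCore.rrule(::typeof(mysquare), x)
    y = mysquare(x)
    function mysquare_pullback(ȳ)
        # First entry: cotangent for the function object itself (none here)
        return NoTangent(), 2x * ȳ
    end
    return y, mysquare_pullback
end

val, pb = rrule(mysquare, 3.0)
g = gradient(mysquare, 3.0)   # Zygote uses the rule defined above
```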

The optional RuleConfig option allows specifying rrules only for AD systems that support given features. If not needed, then it can be omitted and the rrule without it will be hit as a fallback. This is the case for most rules.

See also: frule, @scalar_rule, RuleConfig