Automatic Differentiation using Zygote.jl

Flux re-exports the gradient function from Zygote and uses it within train! to differentiate the model. Zygote has its own documentation, which in particular lists some limitations.

Implicit style

Flux primarily uses what Zygote calls "implicit" gradients, as described in Zygote's own documentation.

Zygote.gradient (Method)
gradient(f, args...)

Returns a tuple containing ∂f/∂x for each argument x: the derivative (for scalar x), or otherwise the gradient.

f(args...) must return a real number; see jacobian for array output.

See also withgradient to keep the value f(args...), and pullback for the value and the back-propagator.

julia> gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)

julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
([14.0, 22.0, 26.0],)

julia> gradient([7, 11], 0, 1) do x, y, d
         p = size(x, d)
         sum(x.^p .+ y)
       end
([14.0, 22.0], 2.0, nothing)

Zygote.Params (Type)
Params([A, B])

Container for implicit parameters, used when differentiating a zero-argument function () -> loss(A, B) with respect to A and B.

Zygote.Grads (Type)
Grads(...)

Dictionary-like container returned when taking gradients with respect to implicit parameters. For an array W, appearing within Params([W, A, B...]), the gradient is g[W].
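
A minimal sketch of the implicit style, assuming Zygote is installed: the loss below reads two global arrays instead of taking them as arguments, and each gradient is looked up in the returned Grads using the array itself as the key. The names W, b, x and loss are illustrative only.

```julia
using Zygote

# Illustrative parameters and input (hypothetical names).
W = [1.0 2.0; 3.0 4.0]
b = [0.5, -0.5]
x = [1.0, 1.0]

# The loss refers to W and b as globals instead of taking them as arguments.
loss() = sum(W * x .+ b)

ps = Params([W, b])
gs = gradient(() -> loss(), ps)   # a Grads object, indexed by the arrays themselves

gs[W]   # ∂loss/∂W[i,j] = x[j], so every entry is 1.0 here
gs[b]   # ∂loss/∂b[i] = 1, so both entries are 1.0
```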

Explicit style

The other way of using Zygote, and the way most other AD packages work, is to provide a function and its arguments explicitly.
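
For comparison, here is a sketch of the same computation in the explicit style, again with illustrative names: the parameters are ordinary arguments of the loss, and gradient returns one entry per argument, in order. For a struct argument, such as a Flux model, the corresponding entry is a NamedTuple mirroring the struct's fields.

```julia
using Zygote

x = [1.0, 1.0]   # fixed input (illustrative), not differentiated

# The parameters are now explicit arguments of the loss.
loss(W, b) = sum(W * x .+ b)

W = [1.0 2.0; 3.0 4.0]
b = [0.5, -0.5]

gW, gb = gradient(loss, W, b)   # one gradient per argument, in argument order
```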


Zygote.withgradient (Method)
withgradient(f, args...)
withgradient(f, ::Params)

Returns both the value of the function and the gradient, as a named tuple.

julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))

julia> ∇ == gradient(/, 1, 2)
true
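
The implicit method, withgradient(f, ::Params), can be sketched the same way. Here w is an illustrative global array; the val field holds the loss and the grad field is a Grads object indexed by w.

```julia
using Zygote

w = [1.0, 2.0, 3.0]   # illustrative parameter array

res = withgradient(() -> sum(abs2, w), Params([w]))

res.val       # 1 + 4 + 9 = 14.0, the value of the loss
res.grad[w]   # gradient of sum(abs2, w) is 2w
```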

Zygote.jacobian (Method)
jacobian(f, args...) -> Tuple

For each array a ∈ args, this returns a matrix with Ja[k,i] = ∂y[k]/∂a[i], where y = f(args...) is usually a vector. Arrays of higher dimension are treated like vec(a), or vec(y) for output.

For scalar x::Number ∈ args, the result is a vector Jx[k] = ∂y[k]/∂x, while for scalar y all results have just one row.

With any other argument type, no result is produced, even if gradient would work.

This reverse-mode Jacobian needs to evaluate the pullback once for each element of y. This is usually efficient only when length(y) is small compared to length(a); otherwise forward mode is likely to be better.

See also withjacobian, hessian, hessian_reverse.
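
To see why the cost scales with length(y), the sketch below (with an illustrative function f) checks that each row of the reverse-mode Jacobian equals the gradient of one scalar output, which is what a single pullback evaluation provides.

```julia
using Zygote

# An illustrative function with 3 outputs and 2 inputs.
f(x) = [x[1] * x[2], sin(x[1]), x[2]^2]

x = [2.0, 3.0]
J = jacobian(f, x)[1]   # 3×2 matrix: J[k, i] = ∂f(x)[k] / ∂x[i]

# Row k of J is the gradient of the scalar function v -> f(v)[k];
# reverse mode obtains each such row from one pullback evaluation.
rows = [gradient(v -> f(v)[k], x)[1] for k in 1:3]
```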

Examples

julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1]  # first index (rows) is output
3×7 Matrix{Int64}:
 200    0    0  0  0  0  0
   0  400    0  0  0  0  0
   0    0  600  0  0  0  0

julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1)  # scalar argument has vector jacobian
([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])

julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])

Warning

For arguments of any type except Number & AbstractArray, the result is nothing.

julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str")
([3 0 0; 0 12 0; 0 0 27], nothing)

julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)

julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))  # gradient understands the tuple
([4, 4, 4], (6, 1))

ChainRules

Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using ChainRules:

ChainRulesCore.ignore_derivatives (Function)
ignore_derivatives(f::Function)

Tells the AD system to ignore the gradients of the wrapped closure. The primal computation (forward pass) is executed normally.

collection = Float64[]  # any mutable container
ignore_derivatives() do
    value = rand()
    push!(collection, value)
end

Using this incorrectly could lead to incorrect gradients. For example, the following function will have zero gradients with respect to its argument:

function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)
    end
    return sum(y)
end
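
A sketch contrasting wrong_grads with a hypothetical right_grads that computes the same value in a way AD can see. It assumes Zygote reports the lost derivative as nothing, its usual representation of a zero gradient; the primal values agree, and only the gradients differ.

```julia
using Zygote
using ChainRulesCore: ignore_derivatives

function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)   # hidden from AD: x's contribution to sum(y) is not tracked
    end
    return sum(y)
end

# Hypothetical fix: same forward value, but x enters the sum visibly.
right_grads(x) = sum(ones(3)) + x

wrong_grads(2.0) == right_grads(2.0)   # the forward passes agree (both 5.0)
gradient(wrong_grads, 2.0)             # expected (nothing,): the derivative is lost
gradient(right_grads, 2.0)             # (1.0,)
```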

ignore_derivatives(x)

Tells the AD system to ignore the gradients of the argument. Can be used to avoid unnecessary computation of gradients.

ignore_derivatives(x) * w
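
A minimal sketch of this argument form with Zygote, using a hypothetical function frozen_times: x is treated as a constant, so only w receives a gradient.

```julia
using Zygote
using ChainRulesCore: ignore_derivatives

# Hypothetical function: x is used in the forward pass but frozen for AD.
frozen_times(x, w) = ignore_derivatives(x) * w

gx, gw = gradient(frozen_times, 3.0, 2.0)   # gw == 3.0; gx is expected to be nothing
```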

ChainRulesCore.@non_differentiable (Macro)
@non_differentiable(signature_expression)

A helper to make it easier to declare that a method is not differentiable. This is a shorthand for defining an frule and an rrule that return NoTangent() for all partials (even for s̄elf, the partial with respect to the function itself).

Keyword arguments should not be included.

julia> @non_differentiable Base.:(==)(a, b)

julia> _, pullback = rrule(==, 2.0, 3.0);

julia> pullback(1.0)
(NoTangent(), NoTangent(), NoTangent())

You can place type-constraints in the signature:

julia> @non_differentiable Base.length(xs::Union{Number, Array})

julia> frule((ZeroTangent(), 1), length, [2.0, 3.0])
(2, NoTangent())
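
A sketch of how such a declaration interacts with Zygote, using a hypothetical helper count_large. Because count_large is piecewise constant, marking it non-differentiable discards nothing, and only the smooth term contributes to the gradient.

```julia
using Zygote, ChainRulesCore

# Hypothetical helper: how many entries exceed 0.5 (a step function).
count_large(x) = count(>(0.5), x)
@non_differentiable count_large(x)

loss(x) = sum(abs2, x) + count_large(x)

gradient(loss, [0.2, 0.9])   # only sum(abs2, x) contributes, giving 2x
```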

Warning

This helper macro covers only the simple common cases. It does not support where-clauses; for these, you can declare the rrule and frule directly.
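
For a signature that needs a where-clause, a hand-written pair of rules can stand in for the macro. The sketch below, for a hypothetical function zero_of_eltype, returns NoTangent() for every partial, including the one for the function itself, which is intended to mirror what @non_differentiable generates for simpler signatures.

```julia
using Zygote, ChainRulesCore

# Hypothetical function whose rule needs a where-clause on the element type.
zero_of_eltype(xs::AbstractVector{T}) where {T<:Real} = zero(T)

# Reverse rule: every partial (function itself, then xs) is NoTangent().
function ChainRulesCore.rrule(::typeof(zero_of_eltype), xs::AbstractVector{T}) where {T<:Real}
    zero_of_eltype_pullback(_) = (NoTangent(), NoTangent())
    return zero_of_eltype(xs), zero_of_eltype_pullback
end

# Forward rule: the output tangent is NoTangent() whatever the input tangents are.
function ChainRulesCore.frule(Δargs, ::typeof(zero_of_eltype), xs::AbstractVector{T}) where {T<:Real}
    return zero_of_eltype(xs), NoTangent()
end

# Only sum(xs) contributes, so the gradient is a vector of ones.
gradient(xs -> sum(xs) + zero_of_eltype(xs), [1.0, 2.0])
```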

To manually supply the gradient for one function, you should define a method of rrule. ChainRules has detailed documentation on how this works.
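
A minimal sketch of a hand-written rrule that Zygote will pick up, for a hypothetical function mysoftplus (a numerically stable log(1 + exp(x))). The pullback returns NoTangent() for the function itself, followed by the cotangent for x.

```julia
using Zygote, ChainRulesCore

# Hypothetical function: softplus(x) = log(1 + exp(x)), in a stable form.
mysoftplus(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

function ChainRulesCore.rrule(::typeof(mysoftplus), x::Real)
    y = mysoftplus(x)
    function mysoftplus_pullback(ȳ)
        # d/dx log(1 + exp(x)) is the logistic sigmoid 1 / (1 + exp(-x)).
        # First entry: tangent for the function itself (it has no fields).
        return NoTangent(), ȳ / (1 + exp(-x))
    end
    return y, mysoftplus_pullback
end

gradient(mysoftplus, 0.0)   # (0.5,), since sigmoid(0) = 1/2
```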

ChainRulesCore.rrule (Function)
rrule([::RuleConfig,] f, x...)

Expressing x as the tuple (x₁, x₂, ...) and the output tuple of f(x...) as Ω, return the tuple:

(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))

Here the second return value is the propagation rule, or pullback. It takes in cotangents corresponding to the outputs (Ω̄₁, Ω̄₂, ...) and returns s̄elf, the cotangent for the internal values of the function itself (relevant for closures), followed by the cotangents for the inputs (x̄₁, x̄₂, ...).

If no method matching rrule(f, xs...) has been defined, then return nothing.

Examples:

unary input, unary output scalar function:

julia> x = rand();

julia> sinx, sin_pullback = rrule(sin, x);

julia> sinx == sin(x)
true

julia> sin_pullback(1) == (NoTangent(), cos(x))
true

binary input, unary output scalar function:

julia> x, y = rand(2);

julia> hypotxy, hypot_pullback = rrule(hypot, x, y);

julia> hypotxy == hypot(x, y)
true

julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
true

The optional RuleConfig argument allows an rrule to be defined only for AD systems that support given features. If it is not needed, it can be omitted, and the rrule without it will be hit as a fallback. This is the case for most rules.
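
A sketch of a rule that does need a RuleConfig, for a hypothetical higher-order function apply_twice. Restricting the config to RuleConfig{>:HasReverseMode} lets the rule call rrule_via_ad, which hands the differentiation of f back to the AD system in use (here Zygote). This is an illustration under those assumptions.

```julia
using Zygote, ChainRulesCore

# Hypothetical higher-order function: apply f twice.
apply_twice(f, x) = f(f(x))

# Only applies to AD systems whose RuleConfig supports reverse mode;
# rrule_via_ad calls back into that AD system to differentiate f.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(apply_twice), f, x)
    y1, back1 = rrule_via_ad(config, f, x)
    y2, back2 = rrule_via_ad(config, f, y1)
    function apply_twice_pullback(ȳ)
        f̄2, ȳ1 = back2(ȳ)    # cotangents for f and for its input y1
        f̄1, x̄ = back1(ȳ1)    # cotangents for f and for the original x
        return NoTangent(), f̄1 + f̄2, x̄   # f is used twice, so its cotangents add
    end
    return y2, apply_twice_pullback
end

gradient(x -> apply_twice(sin, x), 0.5)   # chain rule: cos(sin(0.5)) * cos(0.5)
```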

See also: frule, @scalar_rule, RuleConfig