# Automatic Differentiation using Zygote.jl

Flux re-exports the `gradient` from Zygote, and uses this function within `train!` to differentiate the model. Zygote has its own documentation, in particular listing some important limitations.

## Explicit style

The preferred way of using Zygote, and the only way of using most other AD packages, is to explicitly provide a function and its arguments.
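As a minimal sketch of the explicit style, assuming Zygote is installed, the snippet below passes a hypothetical model (a `NamedTuple` with fields `W` and `b`) as an ordinary argument. The names `model`, `loss`, `x` and `y` are illustrative only. The gradient for a `NamedTuple` argument comes back as a `NamedTuple` with the same field names, which can be checked against the hand-derived derivatives of the squared error.

```julia
using Zygote

# Hypothetical linear model stored as a NamedTuple of arrays.
model = (W = [1.0 2.0; 3.0 4.0], b = [0.5, -0.5])
x = [1.0, 1.0]
y = [2.0, 6.0]

# Squared-error loss; the model is an explicit argument, not a global.
loss(m, x, y) = sum(abs2, m.W * x .+ m.b .- y)

# One gradient entry per argument; for a NamedTuple argument the entry
# is a NamedTuple with the same field names (W and b).
grads = gradient(m -> loss(m, x, y), model)[1]

# With residual r = W*x + b - y: ∂loss/∂b = 2r and ∂loss/∂W = 2r * x'.
r = model.W * x .+ model.b .- y
```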

`Zygote.gradient` (Method)

`gradient(f, args...)`

Returns a tuple containing `∂f/∂x` for each argument `x`, the derivative (for scalar `x`) or the gradient. If no gradient is defined, `∂f/∂x` will be `nothing`.

`f(args...)` must be a real number; see `jacobian` for array output.

See also `withgradient` to keep the value `f(args...)`, and `pullback` for value and back-propagator.

```
julia> gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)

julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
([14.0, 22.0, 26.0],)

julia> gradient([7, 11], 0, 1) do x, y, d
           p = size(x, d)
           sum(x.^p .+ y)
       end
([14.0, 22.0], 2.0, nothing)
```
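One way to sanity-check what `gradient` returns is to compare it against central finite differences. The sketch below assumes Zygote is installed; `f` and `fd_gradient` are hypothetical helpers written for this illustration, not part of Zygote.

```julia
using Zygote

# Hypothetical scalar-valued test function: ∂f/∂xᵢ = 2xᵢ + prod(x)/xᵢ.
f(x) = sum(abs2, x) + prod(x)
x = [7.0, 11.0, 13.0]

# Central finite differences, perturbing one coordinate at a time.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

g_ad = gradient(f, x)[1]   # reverse-mode AD
g_fd = fd_gradient(f, x)   # numerical approximation
```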

`Zygote.withgradient` (Method)

```
withgradient(f, args...)
withgradient(f, ::Params)
```

Returns both the value of the function and the `gradient`, as a named tuple.

```
julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))
julia> ∇ == gradient(/, 1, 2)
true
```

Allows you to capture auxiliary outputs, in addition to the scalar used by `gradient`. To do this, `f` must return a Tuple or NamedTuple. Then it calculates `grad = gradient(first∘f, args...)` but returns the whole `val = f(args...)`:

```
julia> withgradient([1,2,4]) do x
           z = 1 ./ x
           sum(z), z  # here z is an auxiliary output
       end
(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))

julia> withgradient(3.0, 4.0) do x, y
           (div = x/y, mul = x*y)
       end
(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
```

Also supports implicit mode:

```
julia> w = [3.0];
julia> res = withgradient(() -> sum(abs2, w), Params([w]))
(val = 9.0, grad = Grads(...))
julia> res.grad[w]
1-element Vector{Float64}:
 6.0
```
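For comparison, a sketch assuming Zygote is installed: the same derivative computed in implicit mode (a zero-argument closure plus `Params`) and in explicit mode (the array passed as an argument). The array `w` is illustrative; both calls should agree.

```julia
using Zygote

w = [3.0, -1.0]

# Implicit style: the closure reads the global `w`, and Params tells Zygote
# which arrays to track. The result is a Grads object indexed by the array.
gs = gradient(() -> sum(abs2, w), Params([w]))

# Explicit style: the array is an argument, and the result is a tuple.
g_explicit = gradient(v -> sum(abs2, v), w)[1]
```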

`Zygote.jacobian` (Method)

`jacobian(f, args...) -> Tuple`

For each array `a ∈ args` this returns a matrix with `Ja[k,i] = ∂y[k]/∂a[i]` where `y = f(args...)` is usually a vector. Arrays of higher dimension are treated like `vec(a)`, or `vec(y)` for output.
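To illustrate the `vec` convention, assuming Zygote and the `LinearAlgebra` standard library are available: for a 2×2 input and a 2×2 elementwise output, the Jacobian is 4×4, indexed by the flattened positions. The input matrix is arbitrary.

```julia
using Zygote, LinearAlgebra

x = [1.0 2.0; 3.0 4.0]           # 2×2 input
J = jacobian(x -> x .^ 2, x)[1]  # output is also 2×2

# Input and output are both flattened with vec, so J is 4×4 with
# J[k, i] = ∂vec(y)[k] / ∂vec(x)[i], which is 2 * vec(x)[i] when k == i
# and zero otherwise, because the map is elementwise.
```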

For scalar `x::Number ∈ args`, the result is a vector `Jx[k] = ∂y[k]/∂x`, while for scalar `y` all results have just one row.

With any other argument type, no result is produced, even if `gradient` would work.

This reverse-mode Jacobian needs to evaluate the pullback once for each element of `y`. Doing so is usually only efficient when `length(y)` is small compared to `length(a)`; otherwise forward mode is likely to be better.
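The cost remark can be made concrete with a sketch, assuming Zygote is installed: a single `pullback` is called once per output element with a one-hot seed, and each call yields one row of the Jacobian. The function `f` is a hypothetical example; the hand-built matrix should match `jacobian`.

```julia
using Zygote

# Hypothetical function from R^3 to R^3.
f(a) = [sum(a), prod(a), a[1] * a[3]]
a = [1.0, 2.0, 3.0]

y, back = pullback(f, a)

# One pullback call per output element: seeding with the k-th unit
# vector returns the k-th row of the Jacobian.
J_rows = zeros(length(y), length(a))
for k in eachindex(y)
    seed = zeros(length(y))
    seed[k] = 1.0
    J_rows[k, :] = back(seed)[1]
end

J_zygote = jacobian(f, a)[1]
```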

See also `withjacobian`, `hessian`, `hessian_reverse`.

**Examples**

```
julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output
3×7 Matrix{Int64}:
 200    0    0  0  0  0  0
   0  400    0  0  0  0  0
   0    0  600  0  0  0  0
julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian
([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])
julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])
```

For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`.

```
julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str")
([3 0 0; 0 12 0; 0 0 27], nothing)
julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)
julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))  # gradient understands the tuple
([4, 4, 4], (6, 1))
```

`Zygote.withjacobian` (Method)

`withjacobian(f, args...)`

Returns both the value `f(args...)` and the `jacobian` as a named tuple.

```
julia> withjacobian(cumsum, [1,2,3])
(val = [1, 3, 6], grad = ([1 0 0; 1 1 0; 1 1 1],))
```

`Zygote.hessian` (Function)

`hessian(f, x)`

Construct the Hessian `∂²f/∂x²`, where `x` is a real number or an array, and `f(x)` is a real number. When `x` is an array, the result is a matrix `H[i,j] = ∂²f/∂x[i]∂x[j]`, using linear indexing `x[i]` even if the argument is higher-dimensional.

This uses forward over reverse, ForwardDiff over Zygote, calling `hessian_dual(f, x)`. See `hessian_reverse` for an all-Zygote alternative.

See also `diaghessian` to compute only the diagonal part.

**Examples**

```
julia> hessian(x -> x[1]*x[2], randn(2))
2×2 Matrix{Float64}:
 0.0  1.0
 1.0  0.0

julia> hessian(x -> sum(x.^3), [1 2; 3 4])  # uses linear indexing of x
4×4 Matrix{Int64}:
 6   0   0   0
 0  18   0   0
 0   0  12   0
 0   0   0  24

julia> hessian(sin, pi/2)
-1.0
```
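As a further check, assuming Zygote is installed, the Hessian of a small hypothetical polynomial `f(x) = x₁²x₂ + 3x₂²` can be compared with its analytic second derivatives, `[2x₂ 2x₁; 2x₁ 6]`.

```julia
using Zygote

# Hypothetical polynomial with a non-constant Hessian.
f(x) = x[1]^2 * x[2] + 3 * x[2]^2

x = [0.5, -2.0]
H = hessian(f, x)   # ForwardDiff over Zygote, as described above

# Analytic Hessian: ∂²f/∂x₁² = 2x₂, ∂²f/∂x₁∂x₂ = 2x₁, ∂²f/∂x₂² = 6.
H_exact = [2x[2] 2x[1]; 2x[1] 6.0]
```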

`Zygote.hessian_reverse` (Function)

`hessian_reverse(f, x)`

This should be equivalent to `hessian(f, x)`, but implemented using reverse over reverse mode, all Zygote. (This is usually much slower, and more likely to find errors.)

`Zygote.diaghessian` (Function)

`diaghessian(f, args...) -> Tuple`

Diagonal part of the Hessian. Returns a tuple containing, for each argument `x`, `h` of the same shape with `h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]`. The original evaluation `y = f(args...)` must give a real number `y`.

For one vector argument `x`, this is equivalent to `(diag(hessian(f,x)),)`. Like `hessian` it uses ForwardDiff over Zygote.

For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`.

**Examples**

```
julia> diaghessian(x -> sum(x.^3), [1 2; 3 4])[1]
2×2 Matrix{Int64}:
  6  12
 18  24

julia> using LinearAlgebra  # for Diagonal

julia> Diagonal(vec(ans)) == hessian(x -> sum(x.^3), [1 2; 3 4])  # full Hessian is diagonal
true

julia> diaghessian((x,y) -> sum(x .* y .* y'), [1 22; 333 4], [0.5, 0.666])  # two array arguments
([0.0 0.0; 0.0 0.0], [2.0, 8.0])

julia> diaghessian(atan, 1, 2)  # two scalar arguments
(-0.16, 0.16)

julia> hessian(xy -> atan(xy[1], xy[2]), [1, 2])  # full Hessian is not diagonal
2×2 Matrix{Float64}:
 -0.16  -0.12
 -0.12   0.16
```
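A sketch of the relationship to the full Hessian, assuming Zygote and `LinearAlgebra`: the cross term `x[1] * x[2]` in the hypothetical `f` below shows up only off the diagonal, so `diaghessian` returns just the `6x` part while `hessian` also contains the mixed partial.

```julia
using Zygote, LinearAlgebra

# Hypothetical f: the cross term only affects off-diagonal Hessian entries.
f(x) = sum(x .^ 3) + x[1] * x[2]
x = [1.0, 2.0, 3.0]

d = diaghessian(f, x)[1]   # same shape as x
H = hessian(f, x)          # full 3×3 Hessian
```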

`ZygoteRules.pullback` (Function)

```
pullback(f, args...)
pullback(f, ::Params)
```

Returns the value of the function `f` and a back-propagator function, which can be called to obtain a tuple containing `∂f/∂x` for each argument `x`, the derivative (for scalar `x`) or gradient.

```
y, back = pullback(f, args...)
∇ = back(seed)
```

`back` must be called with a start value `seed` matching the output of `f(args...)`. If `f(args...)` returns a number, `seed` should be a number. If `f(args...)` returns an array, `seed` should be an equally-sized array.
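The seed's role can be seen with a linear map, where the Jacobian is the matrix itself. In this sketch (assuming Zygote is installed; `A`, `f` and `seed` are illustrative), calling `back(seed)` returns the vector-Jacobian product, which for `f(x) = A * x` equals `A' * seed`.

```julia
using Zygote

A = [1.0 2.0 3.0; 4.0 5.0 6.0]   # maps R^3 to R^2
f(x) = A * x
x = [1.0, 0.0, -1.0]

y, back = pullback(f, x)         # y is a 2-element vector

# The seed has the same shape as y; the pullback returns seedᵀ J,
# which for this linear map is A' * seed.
seed = [1.0, -2.0]
grad_x = back(seed)[1]
```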

See also `withgradient` to obtain the value and gradients in one call, and `gradient` for obtaining just the gradients.

```
julia> y, back = pullback(*, 2.0, 3.0, 5.0);
julia> y
30.0
julia> back(1.0)
(15.0, 10.0, 6.0)
julia> back(2.0)
(30.0, 20.0, 12.0)
julia> y, back = pullback(x -> [x, x], 1.0);
julia> y
2-element Vector{Float64}:
 1.0
 1.0
julia> back([1.0, 1.0])
(2.0,)
julia> back([2.0, nothing])
(2.0,)
```

## ChainRules

Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using ChainRules:

`ChainRulesCore.ignore_derivatives` (Function)

`ignore_derivatives(f::Function)`

Tells the AD system to ignore the gradients of the wrapped closure. The primal computation (forward pass) is executed normally.

```
ignore_derivatives() do
    value = rand()
    push!(collection, value)
end
```

Using this incorrectly could lead to incorrect gradients. For example, the following function will have zero gradients with respect to its argument:

```
function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)
    end
    return sum(y)
end
```
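A sketch of the failure mode, assuming Zygote and ChainRulesCore are installed: Zygote reports no gradient (`nothing`) for `wrong_grads` even though its value depends on `x`. The `right_grads` variant is a hypothetical fix that builds the array without mutation, so the dependence on `x` stays visible to AD.

```julia
using Zygote
using ChainRulesCore: ignore_derivatives

function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)   # hidden from AD, so x never reaches the gradient
    end
    return sum(y)
end

# Hypothetical fix: no mutation, so AD sees that sum(y) depends on x.
function right_grads(x)
    y = vcat(ones(3), [x])
    return sum(y)
end

g_wrong = gradient(wrong_grads, 2.0)[1]   # no gradient reaches x
g_right = gradient(right_grads, 2.0)[1]   # ∂/∂x of (3 + x) is 1
```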

`ignore_derivatives(x)`

Tells the AD system to ignore the gradients of the argument. Can be used to avoid unnecessary computation of gradients. For example, in `ignore_derivatives(x) * w` only `w` receives a gradient.
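A small sketch, assuming Zygote and ChainRulesCore are installed, with a hypothetical `loss`: wrapping `x` in `ignore_derivatives` makes it a constant for AD, so Zygote returns `nothing` for `x` while `w` still gets its usual gradient.

```julia
using Zygote
using ChainRulesCore: ignore_derivatives

# Hypothetical loss: x is treated as a constant, only w is differentiated.
loss(x, w) = sum(ignore_derivatives(x) .* w)

x = [1.0, 2.0, 3.0]
w = [0.5, 0.5, 0.5]
gx, gw = gradient(loss, x, w)   # gw equals x; gx carries no gradient
```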

`ChainRulesCore.@non_differentiable` (Macro)

`@non_differentiable(signature_expression)`

A helper to make it easier to declare that a method is not differentiable. This is a short-hand for defining an `frule` and `rrule` that return `NoTangent()` for all partials (even for the function `s̄elf`-partial itself).

Keyword arguments should not be included.

```
julia> @non_differentiable Base.:(==)(a, b)
julia> _, pullback = rrule(==, 2.0, 3.0);
julia> pullback(1.0)
(NoTangent(), NoTangent(), NoTangent())
```

You can place type-constraints in the signature:

```
julia> @non_differentiable Base.length(xs::Union{Number, Array})
julia> frule((ZeroTangent(), 1), length, [2.0, 3.0])
(2, NoTangent())
```

This helper macro covers only the simple common cases. It does not support `where`-clauses. For these you can declare the `rrule` and `frule` directly.
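As a sketch with ChainRulesCore only, the hypothetical helper `count_positive` returns an integer count, which has no meaningful derivative. Declaring it with `@non_differentiable` yields an `rrule` whose pullback returns `NoTangent()` for every partial, and an `frule` whose output tangent is `NoTangent()`.

```julia
using ChainRulesCore

# Hypothetical helper: the number of positive entries is an integer,
# so it has no meaningful derivative.
count_positive(xs) = count(>(0), xs)

@non_differentiable count_positive(xs::AbstractArray)

xs = [1.0, -2.0, 3.0]
y, pb = rrule(count_positive, xs)                           # pullback gives NoTangent() everywhere
y2, dy = frule((NoTangent(), ones(3)), count_positive, xs)  # output tangent is NoTangent()
```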

To manually supply the gradient for one function, you should define a method of `rrule`. ChainRules has detailed documentation on how this works.
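A minimal sketch of such a rule, assuming ChainRulesCore and Zygote are installed. The function `mylogistic` is hypothetical; its pullback reuses the primal result `Ω`, since the logistic derivative is `Ω * (1 - Ω)`. Once the `rrule` method exists, Zygote should pick it up when differentiating `mylogistic`.

```julia
using ChainRulesCore
using Zygote

# Hypothetical scalar function used only for this illustration.
mylogistic(x::Real) = 1 / (1 + exp(-x))

function ChainRulesCore.rrule(::typeof(mylogistic), x::Real)
    Ω = mylogistic(x)
    # The derivative is Ω * (1 - Ω), so the pullback reuses the primal result.
    mylogistic_pullback(ȳ) = (NoTangent(), ȳ * Ω * (1 - Ω))
    return Ω, mylogistic_pullback
end

val, pb = rrule(mylogistic, 0.0)   # call the rule directly
g = gradient(mylogistic, 0.0)[1]   # Zygote uses the same rule
```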

`ChainRulesCore.rrule` (Function)

`rrule([::RuleConfig,] f, x...)`

Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)` as `Ω`, return the tuple:

`(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))`

where the second return value is the propagation rule or pullback. It takes in cotangents corresponding to the outputs (`Ω̄₁, Ω̄₂, ...`) and returns cotangents for the inputs (`x̄₁, x̄₂, ...`), together with `s̄elf`, the cotangent for the internal values of the function itself (for closures).

If no method matching `rrule(f, xs...)` has been defined, then return `nothing`.

Examples:

Unary input, unary output scalar function:

```
julia> x = rand();
julia> sinx, sin_pullback = rrule(sin, x);
julia> sinx == sin(x)
true
julia> sin_pullback(1) == (NoTangent(), cos(x))
true
```

Binary input, unary output scalar function:

```
julia> x, y = rand(2);
julia> hypotxy, hypot_pullback = rrule(hypot, x, y);
julia> hypotxy == hypot(x, y)
true
julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
true
```

The optional `RuleConfig` option allows specifying rrules only for AD systems that support given features. If not needed, then it can be omitted and the `rrule` without it will be hit as a fallback. This is the case for most rules.

See also: `frule`, `@scalar_rule`, `RuleConfig`

`ChainRulesCore.frule` (Function)

`frule([::RuleConfig,] (Δf, Δx...), f, x...)`

Expressing the output of `f(x...)` as `Ω`, return the tuple:

`(Ω, ΔΩ)`

The second return value is the tangent w.r.t. the output.

If no method matching `frule((Δf, Δx...), f, x...)` has been defined, then return `nothing`.

Examples:

Unary input, unary output scalar function:

```
julia> dself = NoTangent();
julia> x = rand()
0.8236475079774124
julia> sinx, Δsinx = frule((dself, 1), sin, x)
(0.7336293678134624, 0.6795498147167869)
julia> sinx == sin(x)
true
julia> Δsinx == cos(x)
true
```

Unary input, binary output scalar function:

```
julia> sincosx, Δsincosx = frule((dself, 1), sincos, x);
julia> sincosx == sincos(x)
true
julia> Δsincosx[1] == cos(x)
true
julia> Δsincosx[2] == -sin(x)
true
```

Note that technically speaking Julia does not have multiple-output functions, just functions that return a single output that is iterable, like a `Tuple`. So this is actually a `Tangent`:

```
julia> Δsincosx
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)
```

The optional `RuleConfig` option allows specifying frules only for AD systems that support given features. If not needed, then it can be omitted and the `frule` without it will be hit as a fallback. This is the case for most rules.

See also: `rrule`, `@scalar_rule`, `RuleConfig`
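A hand-written `frule` for a hypothetical logistic function `mylogistic`, sketched under the assumption that ChainRulesCore is installed: it ignores the tangent of the function itself and scales the input tangent `Δx` by the derivative `Ω * (1 - Ω)`.

```julia
using ChainRulesCore

# Hypothetical scalar function used only for this illustration.
mylogistic(x::Real) = 1 / (1 + exp(-x))

function ChainRulesCore.frule((_, Δx), ::typeof(mylogistic), x::Real)
    Ω = mylogistic(x)
    # Push the input tangent forward through the derivative Ω * (1 - Ω).
    return Ω, Ω * (1 - Ω) * Δx
end

y, dy = frule((NoTangent(), 1.0), mylogistic, 0.0)
```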

`ChainRulesCore.@scalar_rule` (Macro)

```
@scalar_rule(f(x₁, x₂, ...),
             @setup(statement₁, statement₂, ...),
             (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
             (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
             ...)
```

A convenience macro that generates simple scalar forward or reverse rules using the provided partial derivatives. Specifically, it generates the corresponding methods for `frule` and `rrule`:

```
function ChainRulesCore.frule((NoTangent(), Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...)
    Ω = f(x₁, x₂, ...)
    $(statement₁, statement₂, ...)
    return Ω, (
            (∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
            (∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
            ...
        )
end

function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
    Ω = f(x₁, x₂, ...)
    $(statement₁, statement₂, ...)
    return Ω, ((ΔΩ₁, ΔΩ₂, ...)) -> (
            NoTangent(),
            ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...,
            ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...,
            ...
        )
end
```

If no type constraints in `f(x₁, x₂, ...)` within the call to `@scalar_rule` are provided, each parameter in the resulting `frule`/`rrule` definition is given a type constraint of `Number`. Constraints may also be explicitly provided to override the `Number` constraint, e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to `Number`.

At present this does not support defining rules for closures/functors. Thus in reverse mode, the first returned partial, representing the derivative with respect to the function itself, is always `NoTangent()`. And in forward mode, the first input to the returned propagator is always ignored.

The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This allows the primal result to be conveniently referenced (as `Ω`) within the derivative/setup expressions.
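As a sketch, assuming ChainRulesCore is installed, the hypothetical `mysigmoid` below gets both rules from one line. Because `Ω` is bound to the primal result, the partial can be written as `Ω * (1 - Ω)`; at `x = 0` the primal is `0.5` and the derivative `0.25`.

```julia
using ChainRulesCore

# Hypothetical scalar function used only for this illustration.
mysigmoid(x) = 1 / (1 + exp(-x))

# Generates both frule and rrule; Ω refers to mysigmoid(x).
@scalar_rule mysigmoid(x) Ω * (1 - Ω)

y, pb = rrule(mysigmoid, 0.0)
y2, dy = frule((NoTangent(), 1.0), mysigmoid, 0.0)
```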

This macro assumes complex functions are holomorphic. In general, for non-holomorphic functions, the `frule` and `rrule` must be defined manually.

If the derivative is one (e.g. for identity functions), `true` can be used as the most general multiplicative identity.

The `@setup` argument can be elided if no setup code is needed. In other words:

```
@scalar_rule(f(x₁, x₂, ...),
             (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
             (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
             ...)
```

is equivalent to:

```
@scalar_rule(f(x₁, x₂, ...),
             @setup(nothing),
             (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
             (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
             ...)
```

For examples, see ChainRules' `rulesets` directory.

`ChainRulesCore.NoTangent` (Type)

`NoTangent() <: AbstractZero`

This tangent indicates that the derivative does not exist. It is the tangent type for primal types that are not differentiable, such as integers or booleans (when they are not being used to represent floating-point values). The only valid way to perturb such values is to not change them at all. As a consequence, `NoTangent` is functionally identical to `ZeroTangent()`, but it provides additional semantic information.

Adding `NoTangent()` to a primal is generally wrong: gradient-based methods cannot be used to optimize over discrete variables. An optimization package making use of this might want to check for such a case.

This does not indicate that the derivative is not implemented, but rather that mathematically it is not defined.

This mostly shows up as the derivative with respect to dimension, index, or size arguments.

```
function rrule(::typeof(fill), x, len::Int)
    y = fill(x, len)
    # x receives the summed cotangent; the integer length has no derivative.
    fill_pullback(ȳ) = (NoTangent(), @thunk(sum(ȳ)), NoTangent())
    return y, fill_pullback
end
```
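From the user side, assuming Zygote is installed, a `NoTangent()` for a size argument surfaces as `nothing`. In this sketch the integer `n` passed to `fill` receives no gradient, while `x` gets `n` (here `3.0`), because it is repeated `n` times in the sum.

```julia
using Zygote

# The size argument n has no derivative; Zygote reports it as `nothing`.
gx, gn = gradient((x, n) -> sum(fill(x, n)), 2.0, 3)
```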

`ChainRulesCore.ZeroTangent` (Type)

`ZeroTangent() <: AbstractZero`

The additive identity for tangents. This is basically the same as `0`. A derivative of `ZeroTangent()` does not propagate through the primal function.
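A sketch of that arithmetic, assuming ChainRulesCore is installed: adding `ZeroTangent()` to a value returns the value unchanged, and multiplying by it should give back a zero tangent rather than a numeric result.

```julia
using ChainRulesCore

z = ZeroTangent()

a = z + 1.5   # additive identity: the other operand is returned
b = 2.0 * z   # multiplication by a zero tangent yields a zero tangent
```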

`ChainRulesCore.RuleConfig` (Type)

`RuleConfig{T}`

The configuration for what rules to use. `T`: **traits**. This should be a `Union` of all special traits needed for rules to be allowed to be defined for your AD. If nothing special, this should be set to `Union{}`.

**AD authors** should define a subtype of `RuleConfig` to use when calling `frule`/`rrule`.

**Rule authors** can dispatch on this config when defining rules. For example:

```
# only define rrule for `pop!` on AD systems where mutation is supported.
rrule(::RuleConfig{>:SupportsMutation}, ::typeof(pop!), ::Vector) = ...

# this definition of map is for any AD that defines a forwards mode
rrule(conf::RuleConfig{>:HasForwardsMode}, ::typeof(map), ::Vector) = ...

# this definition of map is for any AD that only defines a reverse mode.
# It is not as good as the rrule that can be used if the AD defines a forward-mode as well.
rrule(conf::RuleConfig{>:Union{NoForwardsMode, HasReverseMode}}, ::typeof(map), ::Vector) = ...
```
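A runnable sketch of config dispatch, assuming ChainRulesCore is installed. `MyReverseAD` is a hypothetical configuration type for a reverse-only AD, and `twice` a hypothetical function. The rule is only defined for configs that include `HasReverseMode`, so calling `rrule` without a config finds no rule and returns `nothing`.

```julia
using ChainRulesCore

# Hypothetical configuration for a reverse-mode-only AD system.
struct MyReverseAD <: RuleConfig{Union{ChainRulesCore.HasReverseMode, ChainRulesCore.NoForwardsMode}} end

twice(x) = 2x   # hypothetical function

# Only visible to configs whose trait Union includes HasReverseMode.
function ChainRulesCore.rrule(::RuleConfig{>:ChainRulesCore.HasReverseMode}, ::typeof(twice), x::Real)
    twice_pullback(ȳ) = (NoTangent(), 2 * ȳ)
    return twice(x), twice_pullback
end

with_config = rrule(MyReverseAD(), twice, 3.0)   # dispatches to the rule above
without_config = rrule(twice, 3.0)               # no config-free rule, so `nothing`
```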

For more details see rule configurations and calling back into AD.

`ChainRulesCore.Tangent` (Type)

`Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent`

This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`. `P` is the corresponding primal type that this is a tangent for.

`Tangent{P}` should have fields (technically properties) that match a subset of the fields of the primal type, and each should be a tangent type matching the primal type of that field. Fields of `P` that are not present in the `Tangent` are treated as `Zero`.

`T` is an implementation detail representing the backing data structure. For a `Tuple` it will be a `Tuple`, and for everything else it will be a `NamedTuple`. It should not be passed in by the user.

For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly to a tuple. For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values via `tangent.fieldname`. Any fields not explicitly present in the `Tangent` are treated as being set to `ZeroTangent()`. To make a `Tangent` have all the fields of the primal, the `canonicalize` function is provided.
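A sketch of these behaviours, assuming ChainRulesCore is installed, with a hypothetical `Point` struct: a `Tangent` that stores only `x` still answers `t.y` with `ZeroTangent()`, `canonicalize` fills in the missing field, and a tuple tangent supports indexing.

```julia
using ChainRulesCore

struct Point   # hypothetical primal type
    x::Float64
    y::Float64
end

t = Tangent{Point}(x = 2.0)   # only the x field is stored
t.x                           # 2.0
t.y                           # missing field, read as ZeroTangent()

full = canonicalize(t)        # y is now stored explicitly as ZeroTangent()

# Tangents of tuples support indexing like the tuple itself.
tt = Tangent{Tuple{Float64, Float64}}(1.0, -1.0)
```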

`ChainRulesCore.canonicalize` (Function)

`canonicalize(tangent::Tangent{P}) -> Tangent{P}`

Return the canonical `Tangent` for the primal type `P`. The property names of the returned `Tangent` match the field names of the primal, and all fields of `P` not present in the input `tangent` are explicitly set to `ZeroTangent()`.