Automatic Differentiation using Zygote.jl
Flux re-exports the gradient from Zygote, and uses this function within train! to differentiate the model. Zygote has its own documentation, in particular listing some important limitations.
Implicit style
Flux uses primarily what Zygote calls "implicit" gradients, described here in its documentation.
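For example, here is a minimal sketch (the layer size, data, and loss are made up for illustration) of taking implicit-style gradients of a small Flux model, much as train! does internally:
julia> using Flux
julia> model = Dense(3 => 2);  # a hypothetical single-layer model
julia> x, y = rand(Float32, 3), rand(Float32, 2);  # made-up input and target
julia> ps = Flux.params(model);  # implicit parameters: the weight and bias of `model`
julia> gs = gradient(() -> sum(abs2, model(x) .- y), ps)
Grads(...)
julia> size(gs[model.weight])  # gradients are looked up by the parameter arrays themselves
(2, 3)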
Zygote.gradient — Function
gradient(() -> loss(), ps::Params) -> Grads
Gradient with implicit parameters. Takes a zero-argument function, and returns a dictionary-like container, whose keys are arrays x in ps.
See also withgradient to keep the value loss().
julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];
julia> g = gradient(Params([x, y])) do
         sum(x .* y .* z')
       end
Grads(...)
julia> g[x]
2×3 Matrix{Float64}:
 7.0  70.0  700.0
 8.0  80.0  800.0
julia> haskey(g, z) # only x and y are parameters
false
Zygote.Params — Type
Params([A, B])
Container for implicit parameters, used when differentiating a zero-argument function () -> loss(A, B) with respect to A, B.
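A minimal sketch (the array shapes are made up) of building a Params and differentiating a zero-argument closure over its contents:
julia> using Zygote
julia> W, b = rand(2, 3), rand(2);  # hypothetical parameter arrays
julia> ps = Params([W, b]);
julia> loss() = sum(W * ones(3) .+ b);
julia> gs = gradient(loss, ps);
julia> gs[W] == ones(2, 3)  # every entry of W contributes with coefficient 1
true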
Zygote.Grads — Type
Grads(...)
Dictionary-like container returned when taking gradients with respect to implicit parameters. For an array W, appearing within Params([W, A, B...]), the gradient is g[W].
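A minimal sketch (values made up) of reading a gradient out of a Grads and using it for a hand-rolled update step:
julia> using Zygote
julia> w = [2.0, 3.0];
julia> gs = gradient(() -> sum(abs2, w), Params([w]));
julia> gs[w]  # indexed by the parameter array itself
2-element Vector{Float64}:
 4.0
 6.0
julia> w .-= 0.1 .* gs[w];  # one manual gradient-descent step
julia> w
2-element Vector{Float64}:
 1.6
 2.4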
Zygote.jacobian — Method
jacobian(loss, ::Params)
Like gradient with implicit parameters, this method takes a zero-argument function and returns an IdDict-like object, now containing the Jacobian for each parameter.
Examples
julia> xs = [1 2; 3 4]; ys = [5,7,9];
julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))
Grads(...)
julia> Jxy[ys]
2×3 Matrix{Int64}:
 1  0  0
 0  1  0
julia> Jxy[xs]
2×4 Matrix{Int64}:
 2  6  4  8
 2  6  4  8
Explicit style
The other way of using Zygote, shared by most other AD packages, is to explicitly provide a function and its arguments.
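For example, a minimal sketch (model and data made up) of differentiating a Flux model directly in this style, which returns a nested structure of gradients rather than a Grads dictionary:
julia> using Flux
julia> model = Dense(3 => 2);  # a hypothetical single-layer model
julia> x = rand(Float32, 3);  # made-up input
julia> grads = gradient(m -> sum(m(x)), model);  # differentiate with respect to the model itself
julia> size(grads[1].weight)  # the gradient mirrors the model's structure
(2, 3)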
Zygote.gradient — Method
gradient(f, args...)
Returns a tuple containing ∂f/∂x for each argument x, the derivative (for scalar x) or the gradient.
f(args...) must be a real number, see jacobian for array output.
See also withgradient to keep the value f(args...), and pullback for value and back-propagator.
julia> gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)
julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
([14.0, 22.0, 26.0],)
julia> gradient([7, 11], 0, 1) do x, y, d
         p = size(x, d)
         sum(x.^p .+ y)
       end
([14.0, 22.0], 2.0, nothing)
Zygote.withgradient — Method
withgradient(f, args...)
withgradient(f, ::Params)
Returns both the value of the function and the gradient, as a named tuple.
julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))
julia> ∇ == gradient(/, 1, 2) # explicit mode
true
julia> w = [3.0];
julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode
(val = 9.0, grad = Grads(...))
julia> res.grad[w]
1-element Vector{Float64}:
6.0
Zygote.jacobian — Method
jacobian(f, args...) -> Tuple
For each array a ∈ args this returns a matrix with Ja[k,i] = ∂y[k]/∂a[i] where y = f(args...) is usually a vector. Arrays of higher dimension are treated like vec(a), or vec(y) for output.
For scalar x::Number ∈ args, the result is a vector Jx[k] = ∂y[k]/∂x, while for scalar y all results have just one row.
With any other argument type, no result is produced, even if gradient would work.
This reverse-mode Jacobian needs to evaluate the pullback once for each element of y. Doing so is usually only efficient when length(y) is small compared to length(a), otherwise forward mode is likely to be better.
See also withjacobian, hessian, hessian_reverse.
Examples
julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output
3×7 Matrix{Int64}:
 200    0    0  0  0  0  0
   0  400    0  0  0  0  0
   0    0  600  0  0  0  0
julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian
([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])
julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])
For arguments of any type except Number & AbstractArray, the result is nothing.
julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str")
([3 0 0; 0 12 0; 0 0 27], nothing)
julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)
julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient understands the tuple
([4 4 4], (6, 1))
ChainRules
Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using ChainRules:
ChainRulesCore.ignore_derivatives — Function
ignore_derivatives(f::Function)
Tells the AD system to ignore the gradients of the wrapped closure. The primal computation (forward pass) is executed normally.
ignore_derivatives() do
    value = rand()
    push!(collection, value)
end
Using this incorrectly could lead to incorrect gradients. For example, the following function will have zero gradients with respect to its argument:
function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)
    end
    return sum(y)
end
ignore_derivatives(x)
Tells the AD system to ignore the gradients of the argument. Can be used to avoid unnecessary computation of gradients.
ignore_derivatives(x) * w
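For instance, a minimal sketch (values made up) showing that the wrapped argument is treated as a constant by the gradient:
julia> using Zygote, ChainRulesCore
julia> gradient((x, w) -> sum(ignore_derivatives(x) .* w), [1.0, 2.0], [3.0, 4.0])
(nothing, [1.0, 2.0])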
ChainRulesCore.@non_differentiable — Macro
@non_differentiable(signature_expression)
A helper to make it easier to declare that a method is not differentiable. This is a short-hand for defining an frule and rrule that return NoTangent() for all partials (even for the function s̄elf-partial itself).
Keyword arguments should not be included.
julia> @non_differentiable Base.:(==)(a, b)
julia> _, pullback = rrule(==, 2.0, 3.0);
julia> pullback(1.0)
(NoTangent(), NoTangent(), NoTangent())
You can place type-constraints in the signature:
julia> @non_differentiable Base.length(xs::Union{Number, Array})
julia> frule((ZeroTangent(), 1), length, [2.0, 3.0])
(2, NoTangent())
This helper macro covers only the simple common cases. It does not support where-clauses. For these you can declare the rrule and frule directly.
To manually supply the gradient for one function, you should define a method of rrule. ChainRules has detailed documentation on how this works.
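For instance, a minimal sketch (the function myexp2 and its rule are made up for illustration) of defining such a method and checking that Zygote picks it up:
julia> using ChainRulesCore, Zygote
julia> myexp2(x) = 2^x;  # d/dx 2^x = log(2) * 2^x
julia> function ChainRulesCore.rrule(::typeof(myexp2), x)
           y = myexp2(x)
           myexp2_pullback(ȳ) = (NoTangent(), ȳ * log(2) * y)  # no gradient for the function itself, scaled derivative for x
           return y, myexp2_pullback
       end;
julia> gradient(myexp2, 3.0)[1] ≈ log(2) * 2^3.0  # Zygote now uses the hand-written rule
true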
ChainRulesCore.rrule — Function
rrule([::RuleConfig,] f, x...)
Expressing x as the tuple (x₁, x₂, ...) and the output tuple of f(x...) as Ω, return the tuple:
(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))
Where the second return value is the propagation rule or pullback. It takes in cotangents (Ω̄₁, Ω̄₂, ...) corresponding to the outputs, and returns s̄elf, the cotangent for the internal values of the function itself (for closures), together with cotangents x̄₁, x̄₂, ... for the inputs.
If no method matching rrule(f, xs...) has been defined, then return nothing.
Examples:
unary input, unary output scalar function:
julia> x = rand();
julia> sinx, sin_pullback = rrule(sin, x);
julia> sinx == sin(x)
true
julia> sin_pullback(1) == (NoTangent(), cos(x))
true
binary input, unary output scalar function:
julia> x, y = rand(2);
julia> hypotxy, hypot_pullback = rrule(hypot, x, y);
julia> hypotxy == hypot(x, y)
true
julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
true
The optional RuleConfig option allows specifying rrules only for AD systems that support given features. If not needed, then it can be omitted and the rrule without it will be hit as a fallback. This is the case for most rules.
See also: frule, @scalar_rule, RuleConfig.