Automatic Differentiation using Zygote.jl
Flux re-exports the gradient function from Zygote, and uses it within train! to differentiate the model. Zygote has its own documentation, which in particular lists some important limitations.
Implicit style
Flux primarily uses what Zygote calls "implicit" gradients, which are described in Zygote's documentation.
Zygote.gradient - Function
gradient(() -> loss(), ps::Params) -> Grads
Gradient with implicit parameters. Takes a zero-argument function, and returns a dictionary-like container, whose keys are arrays x in ps.
See also withgradient to keep the value loss().
julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];
julia> g = gradient(Params([x, y])) do
         sum(x .* y .* z')
       end
Grads(...)
julia> g[x]
2×3 Matrix{Float64}:
7.0 70.0 700.0
8.0 80.0 800.0
julia> haskey(g, z) # only x and y are parameters
false
Zygote.Params - Type
Params([A, B])
Container for implicit parameters, used when differentiating a zero-argument function () -> loss(A, B) with respect to A, B.
Zygote.Grads - Type
Grads(...)
Dictionary-like container returned when taking gradients with respect to implicit parameters. For an array W, appearing within Params([W, A, B...]), the gradient is g[W].
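To illustrate how Params and Grads fit together, here is a minimal sketch of one gradient-descent step in implicit style. The arrays W, b, x, the loss function, and the step size are hypothetical and chosen only for illustration; this is not how train! is implemented.

```julia
using Zygote

# Hypothetical parameters and input, for illustration only.
W = [1.0 2.0; 3.0 4.0]
b = [0.5, -0.5]
x = [1.0, 1.0]

# Zero-argument loss that closes over the parameter arrays.
loss() = sum(abs2, W * x .+ b)

ps = Zygote.Params([W, b])
before = loss()                 # 54.5

# Grads is indexed by the parameter arrays themselves.
gs = gradient(loss, ps)

step = 0.1                      # hypothetical step size
for p in ps
    p .-= step .* gs[p]         # update each array in place
end

after = loss()                  # smaller than `before`
```

Because the update mutates W and b in place, the same Params object can be reused for the next step.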
Zygote.jacobian - Method
jacobian(loss, ::Params)
Like gradient with implicit parameters, this method takes a zero-argument function and returns an IdDict-like object, now containing the Jacobian for each parameter.
Examples
julia> xs = [1 2; 3 4]; ys = [5,7,9];
julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))
Grads(...)
julia> Jxy[ys]
2×3 Matrix{Int64}:
1 0 0
0 1 0
julia> Jxy[xs]
2×4 Matrix{Int64}:
2 6 4 8
2 6 4 8
Explicit style
The other way of using Zygote, which is also how most other AD packages are used, is to explicitly provide a function and its arguments.
Zygote.gradient - Method
gradient(f, args...)
Returns a tuple containing ∂f/∂x for each argument x, the derivative (for scalar x) or the gradient.
f(args...) must be a real number, see jacobian for array output.
See also withgradient to keep the value f(args...), and pullback for value and back-propagator.
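The pullback function mentioned above returns the value together with a back-propagator, which can then be called with a seed to obtain the gradients. A small sketch, assuming a hypothetical two-argument function f:

```julia
using Zygote

# Hypothetical function of an array and a scalar.
f(x, y) = sum(x .^ 2) * y

# Forward pass: the value, plus a closure that propagates cotangents back.
val, back = Zygote.pullback(f, [1.0, 2.0], 3.0)   # val == 15.0

# Seeding with 1.0 reproduces what gradient(f, [1.0, 2.0], 3.0) returns.
grads = back(1.0)   # gradient for x is [6.0, 12.0], for y is 5.0
```

Calling back with a different seed scales the result, which is what makes pullback useful when f is only one stage of a larger computation.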
julia> gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)
julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
([14.0, 22.0, 26.0],)
julia> gradient([7, 11], 0, 1) do x, y, d
         p = size(x, d)
         sum(x.^p .+ y)
       end
([14.0, 22.0], 2.0, nothing)
Zygote.withgradient - Method
withgradient(f, args...)
withgradient(f, ::Params)
Returns both the value of the function and the gradient, as a named tuple.
julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))
julia> ∇ == gradient(/, 1, 2) # explicit mode
true
julia> w = [3.0];
julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode
(val = 9.0, grad = Grads(...))
julia> res.grad[w]
1-element Vector{Float64}:
6.0
Zygote.jacobian - Method
jacobian(f, args...) -> Tuple
For each array a ∈ args this returns a matrix with Ja[k,i] = ∂y[k]/∂a[i] where y = f(args...) is usually a vector. Arrays of higher dimension are treated like vec(a), or vec(y) for output.
For scalar x::Number ∈ args, the result is a vector Jx[k] = ∂y[k]/∂x, while for scalar y all results have just one row.
With any other argument type, no result is produced, even if gradient would work.
This reverse-mode Jacobian needs to evaluate the pullback once for each element of y. Doing so is usually only efficient when length(y) is small compared to length(a); otherwise forward mode is likely to be better.
See also withjacobian, hessian, hessian_reverse.
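The cost described above can be made concrete by assembling a Jacobian by hand, one row per pullback call. This is only a sketch of the idea, not Zygote's actual implementation; the function f and the input a are hypothetical.

```julia
using Zygote

# Hypothetical function: 7 inputs, 3 outputs.
f(a) = 100 .* a[1:3] .^ 2
a = collect(1.0:7.0)

y, back = Zygote.pullback(f, a)
J = zeros(length(y), length(a))

# One pullback evaluation per output element, each seeded with a one-hot cotangent.
for k in eachindex(y)
    seed = zeros(length(y))
    seed[k] = 1.0
    J[k, :] .= vec(back(seed)[1])
end

matches = J == jacobian(f, a)[1]
```

Here three pullback calls fill a 3×7 matrix; with many outputs and few inputs the loop would be long, which is the case where forward mode tends to be preferable.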
Examples
julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output
3×7 Matrix{Int64}:
200 0 0 0 0 0 0
0 400 0 0 0 0 0
0 0 600 0 0 0 0
julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian
([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])
julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])
For arguments of any type except Number & AbstractArray, the result is nothing.
julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str")
([3 0 0; 0 12 0; 0 0 27], nothing)
julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)
julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient understands the tuple
([4 4 4], (6, 1))
ChainRules
Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using ChainRules:
ChainRulesCore.ignore_derivatives - Function
ignore_derivatives(f::Function)
Tells the AD system to ignore the gradients of the wrapped closure. The primal computation (forward pass) is executed normally.
ignore_derivatives() do
    value = rand()
    push!(collection, value)
end
Using this incorrectly could lead to incorrect gradients. For example, the following function will have zero gradients with respect to its argument:
function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)
    end
    return sum(y)
end
ignore_derivatives(x)
Tells the AD system to ignore the gradients of the argument. Can be used to avoid unnecessary computation of gradients.
ignore_derivatives(x) * w
ChainRulesCore.@non_differentiable - Macro
@non_differentiable(signature_expression)
A helper to make it easier to declare that a method is not differentiable. This is a short-hand for defining an frule and rrule that return NoTangent() for all partials (even for the function s̄elf-partial itself).
Keyword arguments should not be included.
julia> @non_differentiable Base.:(==)(a, b)
julia> _, pullback = rrule(==, 2.0, 3.0);
julia> pullback(1.0)
(NoTangent(), NoTangent(), NoTangent())
You can place type-constraints in the signature:
julia> @non_differentiable Base.length(xs::Union{Number, Array})
julia> frule((ZeroTangent(), 1), length, [2.0, 3.0])
(2, NoTangent())
This helper macro covers only the simple common cases. It does not support where-clauses. For these you can declare the rrule and frule directly.
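As a rough sketch of how the two tools above might be used together with Zygote, the following assumes that both Zygote and ChainRulesCore are installed. The names history, logged_loss, count_positive and scaled are hypothetical and exist only for this example.

```julia
using Zygote
using ChainRulesCore

# Hypothetical log of loss values, filled as a side effect.
history = Float64[]

function logged_loss(w)
    l = sum(abs2, w)
    ignore_derivatives() do
        push!(history, l)       # runs in the forward pass, excluded from AD
    end
    return l
end

g1 = gradient(logged_loss, [1.0, 2.0])[1]     # [2.0, 4.0], unaffected by the logging

# Hypothetical helper whose integer result should be treated as a constant.
count_positive(w) = count(>(0), w)
@non_differentiable count_positive(w)

scaled(w) = sum(w) / count_positive(w)
g2 = gradient(scaled, [1.0, -2.0, 3.0])[1]    # every entry is 0.5
```

In both cases the excluded code only observes values; nothing that the loss depends on is computed inside it, which is what distinguishes this from the wrong_grads example.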
To manually supply the gradient for one function, you should define a method of rrule. ChainRules has detailed documentation on how this works.
ChainRulesCore.rrule - Function
rrule([::RuleConfig,] f, x...)
Expressing x as the tuple (x₁, x₂, ...) and the output tuple of f(x...) as Ω, return the tuple:
(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))
where the second return value is the propagation rule or pullback. It takes in cotangents corresponding to the outputs (Ω̄₁, Ω̄₂, ...), and returns s̄elf, the cotangent for the internal values of the function itself (for closures), followed by the cotangents x̄₁, x̄₂, ... for the inputs.
If no method matching rrule(f, xs...) has been defined, then return nothing.
Examples:
unary input, unary output scalar function:
julia> x = rand();
julia> sinx, sin_pullback = rrule(sin, x);
julia> sinx == sin(x)
true
julia> sin_pullback(1) == (NoTangent(), cos(x))
true
binary input, unary output scalar function:
julia> x, y = rand(2);
julia> hypotxy, hypot_pullback = rrule(hypot, x, y);
julia> hypotxy == hypot(x, y)
true
julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
true
The optional RuleConfig option allows specifying rrules only for AD systems that support given features. If not needed, then it can be omitted and the rrule without it will be hit as a fallback. This is the case for most rules.
See also: frule, @scalar_rule, RuleConfig
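For completeness, here is a minimal sketch of manually supplying a gradient by defining a method of rrule, assuming Zygote and ChainRulesCore are both installed. The function softclip is hypothetical and stands in for any function whose derivative you want to provide by hand.

```julia
using Zygote
using ChainRulesCore

# Hypothetical scalar function, for illustration only.
softclip(x::Real) = x / (1 + abs(x))

# Hand-written reverse rule: return the primal value and a pullback.
function ChainRulesCore.rrule(::typeof(softclip), x::Real)
    y = softclip(x)
    # d/dx [x / (1 + |x|)] = 1 / (1 + |x|)^2
    function softclip_pullback(dy)
        return (NoTangent(), dy / (1 + abs(x))^2)
    end
    return y, softclip_pullback
end

y, back = rrule(softclip, 1.0)    # y == 0.5
tangents = back(1.0)              # (NoTangent(), 0.25)

# Zygote picks up rules defined this way.
g = gradient(softclip, 1.0)       # (0.25,)
```

The first element of the pullback's result is NoTangent() because softclip is a plain function with no internal fields to differentiate.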