Utilities
Zygote provides a set of helpful utilities. These are all "user-level" tools: you could easily have written them yourself, but they live in Zygote for convenience.
Zygote.@showgrad (macro)
@showgrad(x) -> x
Much like @show, but shows the gradient about to accumulate to x. Useful for debugging gradients.
```julia
julia> gradient(2, 3) do a, b
         @showgrad(a)*b
       end
∂(a) = 3
(3, 2)
```
Note that the gradient depends on how the output of @showgrad is used, and is not the overall gradient of the variable a. For example:
```julia
julia> gradient(2) do a
         @showgrad(a)*a
       end
∂(a) = 2
(4,)

julia> gradient(2, 3) do a, b
         @showgrad(a) # not used, so no gradient
         a*b
       end
∂(a) = nothing
(3, 2)
```
Zygote.hook (function)
hook(x̄ -> ..., x) -> x
Gradient hooks. Allows you to apply an arbitrary function to the gradient for x.
```julia
julia> gradient(2, 3) do a, b
         hook(ā -> @show(ā), a)*b
       end
ā = 3
(3, 2)

julia> gradient(2, 3) do a, b
         hook(-, a)*b
       end
(-3, 2)
```
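Since the hook receives the gradient for x and returns whatever should replace it, one natural use is rescaling or clipping a gradient at a specific point in the computation. A minimal sketch, assuming clipping to [-1, 1] purely for illustration; clip is a hypothetical helper, and hook is called as Zygote.hook in case it is not exported by your Zygote version:

```julia
using Zygote

# Hypothetical helper: clamp an incoming gradient to [-1, 1].
clip(ā) = clamp(ā, -1, 1)

# Without the hook, d/da (a * 10) = 10.
gradient(a -> a * 10, 2.0)                     # (10.0,)

# With the hook, the gradient reaching `a` is clipped to 1.0.
gradient(a -> Zygote.hook(clip, a) * 10, 2.0)  # (1.0,)
```

Only the gradient is altered; the forward value of hook(clip, a) is still just a.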
Zygote.dropgrad (function)
dropgrad(x) -> x
Drop the gradient of x: the value of x is returned unchanged on the forward pass, but no gradient flows back through it.
```julia
julia> gradient(2, 3) do a, b
         dropgrad(a)*b
       end
(nothing, 2)
```
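A common use of dropgrad is as a "stop-gradient": a value takes part in the forward computation but is treated as a constant on the backward pass. In this minimal sketch (f is a hypothetical example function), x / dropgrad(x) always evaluates to 1, yet its gradient is 1/x because only the numerator is differentiated; without dropgrad the two contributions cancel:

```julia
using Zygote

# Numerator is differentiated; the denominator is treated as a constant.
f(x) = x / Zygote.dropgrad(x)

f(2.0)                        # 1.0
gradient(f, 2.0)              # (0.5,)  since d/dx (x / c) = 1 / c
gradient(x -> x / x, 2.0)     # (0.0,)  both terms differentiated, they cancel
```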
Zygote.hessian (function)
hessian(f, x)
Construct the Hessian (the matrix of second derivatives) of f, where x is a real number or an array of reals and f(x) is a real number.
```julia
julia> hessian(((a, b),) -> a*b, [2, 3])
2×2 Array{Int64,2}:
 0  1
 1  0
```
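To see that the entries are the second partial derivatives, this sketch uses a hypothetical function poly(a, b) = a^2 * b whose Hessian can be worked out by hand: ∂²f/∂a² = 2b, ∂²f/∂a∂b = 2a and ∂²f/∂b² = 0, which at (a, b) = (1, 2) gives [4 2; 2 0]. hessian is called as Zygote.hessian in case it is not exported by your version:

```julia
using Zygote

# Illustrative function of x = [a, b], destructured in the argument list.
# Analytic Hessian: [2b 2a; 2a 0], i.e. [4 2; 2 0] at (a, b) = (1, 2).
poly((a, b)) = a^2 * b

H = Zygote.hessian(poly, [1.0, 2.0])
```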
Zygote.Buffer (type)
Buffer(xs, ...)
Buffer is an array-like type which is mutable when taking gradients. You can construct a Buffer with the same syntax as similar (e.g. Buffer(xs, 5)) and then use normal indexing. Finally, use copy to get back a normal array.
For example:
```julia
julia> function vstack(xs)
         buf = Buffer(xs, length(xs), 5)
         for i = 1:5
           buf[:, i] = xs
         end
         return copy(buf)
       end
vstack (generic function with 1 method)

julia> vstack([1, 2, 3])
3×5 Array{Int64,2}:
 1  1  1  1  1
 2  2  2  2  2
 3  3  3  3  3

julia> gradient(x -> sum(vstack(x)), [1, 2, 3])
([5.0, 5.0, 5.0],)
```
Buffer is not an AbstractArray and can't be used for linear algebra operations like matrix multiplication. This prevents it from being captured by pullbacks.

copy is a semantic copy, but does not allocate memory. Instead, the Buffer is made immutable after copying.
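Buffer entries can also be read back inside the differentiated code, which is convenient for recurrences where each entry depends on the previous one. A minimal sketch, assuming a hypothetical running_sum helper (Base's cumsum already does this; it is reimplemented here only to exercise Buffer). For a length-3 input, sum(running_sum(x)) = 3*x[1] + 2*x[2] + x[3], so the expected gradient is [3, 2, 1]:

```julia
using Zygote

# Hypothetical helper: running sum written with explicit mutation.
function running_sum(x)
    buf = Zygote.Buffer(x)           # same shape and eltype as similar(x)
    buf[1] = x[1]
    for i in 2:length(x)
        buf[i] = buf[i-1] + x[i]     # read an earlier entry, then write
    end
    return copy(buf)                 # back to an ordinary array
end

x = [1.0, 2.0, 3.0]
running_sum(x)                         # [1.0, 3.0, 6.0]
gradient(v -> sum(running_sum(v)), x)  # ([3.0, 2.0, 1.0],)
```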
Zygote.forwarddiff (function)
forwarddiff(f, x) -> f(x)
Runs f(x) as usual, but instructs Zygote to differentiate f using forward mode rather than the usual reverse mode.
Forward mode takes time linear in length(x) but has only constant memory overhead, and is very efficient for scalars, so in some cases this can be a useful optimisation.
```julia
julia> function pow(x, n)
         r = one(x)
         for i = 1:n
           r *= x
         end
         return r
       end
pow (generic function with 1 method)

julia> gradient(5) do x
         forwarddiff(x) do x
           pow(x, 2)
         end
       end
(10,)
```
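forwarddiff is not limited to scalars: x may also be an array, in which case the derivative with respect to every element is computed in forward mode. A minimal sanity check, using a hypothetical sumsq objective whose gradient is 2x, that wrapping a function in forwarddiff gives the same gradient as plain reverse mode:

```julia
using Zygote

# Hypothetical objective: sum of squares, whose gradient is 2v.
sumsq(v) = sum(abs2, v)

x = [1.0, 2.0, 3.0]

# Differentiate sumsq in forward mode, inside an outer reverse-mode call.
g_fwd = gradient(v -> Zygote.forwarddiff(sumsq, v), x)[1]

# Differentiate sumsq with Zygote's usual reverse mode.
g_rev = gradient(sumsq, x)[1]

g_fwd ≈ g_rev ≈ [2.0, 4.0, 6.0]   # true
```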
Note that gradients are dropped for any values that f closes over; only x is differentiated.
```julia
julia> gradient(2, 3) do a, b
         forwarddiff(a) do a
           a*b
         end
       end
(3, nothing)
```
To keep the gradient for b, pass it through explicitly as part of x:
```julia
gradient(2, 3) do a, b
  forwarddiff([a, b]) do (a, b)
    a*b
  end
end
```
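The difference between the two forms can be checked directly. This sketch uses Float64 inputs and calls Zygote.forwarddiff fully qualified; with b captured by the closure its gradient comes back as nothing, while passing [a, b] recovers both partial derivatives, ∂(a*b)/∂a = b = 3 and ∂(a*b)/∂b = a = 2:

```julia
using Zygote

# b is captured by the closure, so forwarddiff drops its gradient.
g_closed = gradient(2.0, 3.0) do a, b
    Zygote.forwarddiff(a) do a
        a * b
    end
end                                   # (3.0, nothing)

# b is passed in explicitly, so both gradients survive.
g_passed = gradient(2.0, 3.0) do a, b
    Zygote.forwarddiff([a, b]) do (a, b)
        a * b
    end
end                                   # (3.0, 2.0)
```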
Zygote.ignore (function)
```julia
ignore() do
  ...
end
```
Tell Zygote to ignore a block of code. Everything inside the do block runs on the forward pass as normal, but Zygote won't try to differentiate it at all. This can be useful for, e.g., code that logs values from the forward pass.
Use this with care: if code inside the block contributes to the value being differentiated, the gradients you get back will be silently wrong.
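As an illustration of the logging use case, this sketch records an intermediate value during the forward pass. The losses vector is a hypothetical log; push! runs normally inside the ignore block, and the gradient is that of x^2 alone:

```julia
using Zygote

losses = Float64[]   # hypothetical forward-pass log

g = gradient(3.0) do x
    y = x^2
    Zygote.ignore() do
        push!(losses, y)   # runs on the forward pass; never differentiated
    end
    y
end

g        # (6.0,)
losses   # [9.0]
```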
Params and Grads can be copied to and from arrays using the copy! function.
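A typical use is flattening all parameters (or their gradients) into one vector, e.g. to hand them to an optimiser that works on plain vectors. This minimal sketch assumes that entries are laid out parameter by parameter in the order given to Params, each in column-major (vec) order, and that the vector's length equals the total number of parameter entries; W, b and the loss are illustrative:

```julia
using Zygote

W = [1.0 2.0; 3.0 4.0]
b = [5.0, 6.0]
ps = Zygote.Params([W, b])   # implicit parameters, in this order

# Params -> flat vector (assumed layout: vec(W) followed by b).
θ = zeros(6)
copy!(θ, ps)                 # θ == [1.0, 3.0, 2.0, 4.0, 5.0, 6.0]

# Flat vector -> Params: overwrites W and b in place.
copy!(ps, θ .* 2)            # W == [2.0 4.0; 6.0 8.0], b == [10.0, 12.0]

# Grads -> flat vector, in the same parameter order.
gs = gradient(() -> sum(W) + 2 * sum(b), ps)
∇θ = zeros(6)
copy!(∇θ, gs)                # ∇θ == [1.0, 1.0, 1.0, 1.0, 2.0, 2.0]
```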