Flat vs. Nested Structures
A Flux model is a nested structure, with parameters stored within many layers. Sometimes you may want a flat representation of them, to interact with functions expecting just one vector. This is provided by destructure:
julia> model = Chain(Dense(2=>1, tanh), Dense(1=>1))
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters
  Dense(1 => 1),                        # 2 parameters
)                   # Total: 4 arrays, 5 parameters, 276 bytes.

julia> flat, rebuild = Flux.destructure(model)
(Float32[0.863101, 1.2454957, 0.0, -1.6345707, 0.0], Restructure(Chain, ..., 5))

julia> rebuild(zeros(5))  # same structure, new parameters
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters  (all zero)
  Dense(1 => 1),                        # 2 parameters  (all zero)
)                   # Total: 4 arrays, 5 parameters, 276 bytes.
Both destructure and the Restructure function can be used within gradient computations. For instance, this computes the Hessian ∂²L/∂θᵢ∂θⱼ of some loss function, with respect to all parameters of the Flux model. The resulting matrix has off-diagonal entries, which cannot really be expressed in a nested structure:
julia> x = rand(Float32, 2, 16);

julia> grad = gradient(m -> sum(abs2, m(x)), model)  # nested gradient
((layers = ((weight = Float32[10.339018 11.379145], bias = Float32[22.845667], σ = nothing), (weight = Float32[-29.565302;;], bias = Float32[-37.644184], σ = nothing)),),)

julia> function loss(v::Vector)
         m = rebuild(v)
         y = m(x)
         sum(abs2, y)
       end;

julia> gradient(loss, flat)  # flat gradient, same numbers
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184],)

julia> Zygote.hessian(loss, flat)  # second derivative
5×5 Matrix{Float32}:
  -7.13131   -5.54714  -11.1393  -12.6504   -8.13492
  -5.54714   -7.11092  -11.0208  -13.9231   -9.36316
 -11.1393   -11.0208   -13.7126  -27.9531  -22.741
 -12.6504   -13.9231   -27.9531   18.0875   23.03
  -8.13492   -9.36316  -22.741    23.03     32.0

julia> Flux.destructure(grad)  # acts on non-models, too
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5))
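One appeal of the flat view is that a whole-model update becomes a single vector operation. Below is a minimal sketch of one plain gradient-descent step on the flat parameters, continuing from the variables above; the step size 0.01 is an arbitrary illustration, not a recommendation:

julia> flat2 = flat .- 0.01f0 .* gradient(loss, flat)[1];  # one gradient-descent step on the flat vector

julia> model2 = rebuild(flat2);  # same structure as `model`, with the updated parameters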
All Parameters
The function destructure now lives in Optimisers.jl. (Be warned this package is unrelated to the Flux.Optimisers sub-module! The confusion is temporary.)
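Since the implementation lives in Optimisers.jl, it can also be called from that package directly. A small sketch, continuing from the model and flat defined above (the final equality assumes Flux simply re-exports the Optimisers function):

julia> using Optimisers

julia> v, re = Optimisers.destructure(model);

julia> v == flat  # assumed: Flux.destructure and Optimisers.destructure are the same function
true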
Optimisers.destructure (Function)
destructure(model) -> vector, reconstructor
Copies all trainable, isnumeric parameters in the model to a vector, and returns also a function which reverses this transformation. Differentiable.
Example
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))
julia> re([3, 5, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))
If model contains various number types, they are promoted to make vector, and are usually restored by Restructure. Such restoration follows the rules of ChainRulesCore.ProjectTo, and thus will restore floating point precision, but will permit more exotic numbers like ForwardDiff.Dual.

If model contains only GPU arrays, then vector will also live on the GPU. At present, a mixture of GPU and ordinary CPU arrays is undefined behaviour.
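To make the promotion and restoration rules above concrete, here is a hedged sketch with one Float32 and one Float64 leaf (the names a and b are made up for the example):

julia> mixed = (a = Float32[1, 2], b = [3.0]);

julia> v, re = destructure(mixed);

julia> eltype(v)  # promoted to the wider element type
Float64

julia> re(v)      # ProjectTo restores Float32 precision for `a`
(a = Float32[1.0, 2.0], b = [3.0])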
Optimisers.trainable (Function)
trainable(x::Layer) -> NamedTuple
This should be overloaded to make optimisers ignore some fields of every Layer, which would otherwise contain trainable parameters. (Elements such as functions and sizes are always ignored.)

The default is Functors.children(x), usually a NamedTuple of all fields, and trainable(x) must contain a subset of these.
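For instance, a custom layer can mark only one of its fields as trainable. This is a sketch only: PinnedScale is a made-up struct, not part of Flux, and the overload keeps shift out of the flat vector:

julia> using Optimisers

julia> struct PinnedScale
         scale::Vector{Float32}   # to be trained
         shift::Vector{Float32}   # kept fixed by the overload below
       end

julia> Flux.@functor PinnedScale

julia> Optimisers.trainable(x::PinnedScale) = (; scale = x.scale);

julia> destructure(PinnedScale([1f0, 2f0], [3f0]))[1]  # only `scale` reaches the flat vector
2-element Vector{Float32}:
 1.0
 2.0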
Optimisers.isnumeric (Function)
isnumeric(x) -> Bool
Returns true on any parameter to be adjusted by Optimisers.jl, namely arrays of non-integer numbers. Returns false on all other types.

Requires also that Functors.isleaf(x) == true, to focus on e.g. the parent of a transposed matrix, not the wrapper.
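A few illustrative calls (a sketch; the qualified name is used since isnumeric is an Optimisers.jl function rather than something Flux exports):

julia> Optimisers.isnumeric(rand(Float32, 3))  # array of floats: a parameter
true

julia> Optimisers.isnumeric([1, 2, 3])         # array of integers: not adjusted
false

julia> Optimisers.isnumeric(tanh)              # functions are never parameters
false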
All Layers
Another kind of flat view of a nested model is provided by the modules command. This extracts a list of all layers:
Flux.modules (Function)
modules(m)
Return an iterator over non-leaf objects that can be reached by recursing m over the children given by functor.
Useful for applying a function (e.g. a regularizer) over specific modules or subsets of the parameters (e.g. the weights but not the biases).
Examples
julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu));

julia> m2 = Chain(m1, Dense(64, 10))
Chain(
  Chain(
    Dense(784 => 64),                   # 50_240 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
  ),
  Dense(64 => 10),                      # 650 parameters
)         # Total: 6 trainable arrays, 51_018 parameters,
          # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.

julia> Flux.modules(m2)
7-element Vector{Any}:
 Chain(Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))  # 51_018 parameters, plus 128 non-trainable
 (Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))
 Chain(Dense(784 => 64), BatchNorm(64, relu))  # 50_368 parameters, plus 128 non-trainable
 (Dense(784 => 64), BatchNorm(64, relu))
 Dense(784 => 64)      # 50_240 parameters
 BatchNorm(64, relu)   # 128 parameters, plus 128 non-trainable
 Dense(64 => 10)       # 650 parameters

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)

julia> L2(m2) isa Float32
true