Flat vs. Nested Structures

A Flux model is a nested structure, with parameters stored within many layers. Sometimes you may want a flat representation of them, to interact with functions expecting just one vector. This is provided by destructure:

julia> model = Chain(Dense(2=>1, tanh), Dense(1=>1))
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters
  Dense(1 => 1),                        # 2 parameters
)                   # Total: 4 arrays, 5 parameters, 276 bytes.

julia> flat, rebuild = Flux.destructure(model)
(Float32[0.863101, 1.2454957, 0.0, -1.6345707, 0.0], Restructure(Chain, ..., 5))

julia> rebuild(zeros(5))  # same structure, new parameters
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters  (all zero)
  Dense(1 => 1),                        # 2 parameters  (all zero)
)                   # Total: 4 arrays, 5 parameters, 276 bytes.

Both destructure and the Restructure function can be used within gradient computations. For instance, this computes the Hessian ∂²L/∂θᵢ∂θⱼ of some loss function, with respect to all parameters of the Flux model. The resulting matrix has off-diagonal entries, which cannot really be expressed in a nested structure:

julia> x = rand(Float32, 2, 16);

julia> grad = gradient(m -> sum(abs2, m(x)), model)  # nested gradient
((layers = ((weight = Float32[10.339018 11.379145], bias = Float32[22.845667], σ = nothing), (weight = Float32[-29.565302;;], bias = Float32[-37.644184], σ = nothing)),),)

julia> function loss(v::Vector)
         m = rebuild(v)
         y = m(x)
         sum(abs2, y)
       end;

julia> gradient(loss, flat)  # flat gradient, same numbers
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184],)

julia> Zygote.hessian(loss, flat)  # second derivative
5×5 Matrix{Float32}:
  -7.13131   -5.54714  -11.1393  -12.6504   -8.13492
  -5.54714   -7.11092  -11.0208  -13.9231   -9.36316
 -11.1393   -11.0208   -13.7126  -27.9531  -22.741
 -12.6504   -13.9231   -27.9531   18.0875   23.03
  -8.13492   -9.36316  -22.741    23.03     32.0

julia> Flux.destructure(grad)  # acts on non-models, too
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5))
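
The flat vector also makes it easy to hand a model to code that only understands plain vectors. The following is a minimal sketch, not part of Flux itself: the helper name descend and the step size eta are assumptions chosen for illustration, and the loop is ordinary gradient descent written directly on the flat parameters.

```julia
using Flux

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
flat, rebuild = Flux.destructure(model)

x = rand(Float32, 2, 16)

# Loss as a function of the flat parameter vector only.
loss(v) = sum(abs2, rebuild(v)(x))

# Hypothetical helper: plain gradient descent on the flat vector.
function descend(v, steps; eta = 0.01f0)
    for _ in 1:steps
        g = Flux.gradient(loss, v)[1]  # flat gradient, same length as v
        v = v .- eta .* g
    end
    return v
end

new_flat = descend(flat, 10)
new_model = rebuild(new_flat)  # same structure as model, updated parameters
```

Keeping eta as a Float32 avoids promoting the parameters to Float64 during the update.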

All Parameters

The function destructure now lives in Optimisers.jl. (Be warned this package is unrelated to the Flux.Optimisers sub-module! The confusion is temporary.)

Optimisers.destructure — Function
destructure(model) -> vector, reconstructor

Copies all trainable, isnumeric parameters in the model to a vector, and returns also a function which reverses this transformation. Differentiable.

Example

julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))

julia> re([3, 5, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))

If model contains various number types, they are promoted to make vector, and are usually restored by Restructure. Such restoration follows the rules of ChainRulesCore.ProjectTo, and thus will restore floating point precision, but will permit more exotic numbers like ForwardDiff.Dual.

If model contains only GPU arrays, then vector will also live on the GPU. At present, a mixture of GPU and ordinary CPU arrays is undefined behaviour.
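
As an illustration of the promotion rule described above, here is a small sketch using a NamedTuple with one Float32 array and one Float64 array. The field names a and b are arbitrary choices for this example.

```julia
using Optimisers

# A "model" mixing two floating point precisions.
mixed = (a = Float32[1, 2], b = [3.0, 4.0])

v, re = Optimisers.destructure(mixed)

# The flat vector is promoted to the wider element type.
eltype(v)

# Rebuilding restores the precision of each original array.
restored = re(v)
eltype(restored.a)
eltype(restored.b)
```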

Optimisers.trainable — Function
trainable(x::Layer) -> NamedTuple

This should be overloaded to make optimisers ignore some fields of every Layer, which would otherwise contain trainable parameters. (Elements such as functions and sizes are always ignored.)

The default is Functors.children(x), usually a NamedTuple of all fields, and trainable(x) must contain a subset of these.
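
A minimal sketch of such an overload follows. The Affine type and its field names are hypothetical, introduced only for this example. The Functors.@functor annotation is included on the assumption that the installed Functors.jl version needs it; newer versions may treat plain structs as functors without it.

```julia
using Optimisers, Functors

# Hypothetical layer type, for illustration only.
struct Affine
    weight
    bias
    scale   # an array that optimisers should leave alone
end
Functors.@functor Affine

# Declare that only weight and bias are trainable; scale is ignored.
Optimisers.trainable(a::Affine) = (; weight = a.weight, bias = a.bias)

layer = Affine([1.0 2.0; 3.0 4.0], [5.0, 6.0], [7.0, 8.0])

# destructure copies only the trainable arrays into the flat vector.
v, re = Optimisers.destructure(layer)

# Rebuilding from new parameters keeps the non-trainable field as it was.
rebuilt = re(zero(v))
```

Here v holds the six entries of weight and bias, while scale never enters the flat vector.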

Optimisers.isnumeric — Function
isnumeric(x) -> Bool

Returns true on any parameter to be adjusted by Optimisers.jl, namely arrays of non-integer numbers. Returns false on all other types.

Requires also that Functors.isleaf(x) == true, to focus on e.g. the parent of a transposed matrix, not the wrapper.
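
A few calls sketch how these rules play out. The function is not exported, so it is qualified with the module name. The last call reflects the note above about wrappers and is shown without asserting its result.

```julia
using Optimisers

float_array = Optimisers.isnumeric([1.0, 2.0])  # array of floats: a parameter
int_array   = Optimisers.isnumeric([1, 2])      # array of integers: not adjusted
a_function  = Optimisers.isnumeric(tanh)        # not an array at all

# An adjoint wraps its parent array; per the note above, the wrapper itself
# is not expected to count as a parameter, only its parent does.
wrapped = Optimisers.isnumeric([1.0 2.0]')
```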

All Layers

Another kind of flat view of a nested model is provided by the modules function, which extracts a list of all layers:

Flux.modules — Function
modules(m)

Return an iterator over non-leaf objects that can be reached by recursing m over the children given by functor.

Useful for applying a function (e.g. a regularizer) over specific modules or subsets of the parameters (e.g. the weights but not the biases).

Examples

julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu));

julia> m2 = Chain(m1, Dense(64, 10))
Chain(
  Chain(
    Dense(784 => 64),                   # 50_240 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
  ),
  Dense(64 => 10),                      # 650 parameters
)         # Total: 6 trainable arrays, 51_018 parameters,
          # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.

julia> Flux.modules(m2)
7-element Vector{Any}:
 Chain(Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))  # 51_018 parameters, plus 128 non-trainable
 (Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))
 Chain(Dense(784 => 64), BatchNorm(64, relu))  # 50_368 parameters, plus 128 non-trainable
 (Dense(784 => 64), BatchNorm(64, relu))
 Dense(784 => 64)    # 50_240 parameters
 BatchNorm(64, relu)  # 128 parameters, plus 128 non-trainable
 Dense(64 => 10)     # 650 parameters

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)

julia> L2(m2) isa Float32
true