# Flat vs. Nested Structures

A Flux model is a nested structure, with parameters stored within many layers. Sometimes you may want a flat representation of them, to interact with functions expecting just one vector. This is provided by `destructure`:

```
julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters
  Dense(1 => 1),                        # 2 parameters
)                   # Total: 4 arrays, 5 parameters, 276 bytes.

julia> flat, rebuild = Flux.destructure(model)
(Float32[0.863101, 1.2454957, 0.0, -1.6345707, 0.0], Restructure(Chain, ..., 5))

julia> rebuild(zeros(5))  # same structure, new parameters
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters  (all zero)
  Dense(1 => 1),                        # 2 parameters  (all zero)
)                   # Total: 4 arrays, 5 parameters, 276 bytes.
```

Both `destructure` and the `Restructure` function can be used within gradient computations. For instance, this computes the Hessian `∂²L/∂θᵢ∂θⱼ` of some loss function, with respect to all parameters of the Flux model. The resulting matrix has off-diagonal entries, which cannot really be expressed in a nested structure:

```
julia> x = rand(Float32, 2, 16);

julia> grad = gradient(m -> sum(abs2, m(x)), model)  # nested gradient
((layers = ((weight = Float32[10.339018 11.379145], bias = Float32[22.845667], σ = nothing), (weight = Float32[-29.565302;;], bias = Float32[-37.644184], σ = nothing)),),)

julia> function loss(v::Vector)
         m = rebuild(v)
         y = m(x)
         sum(abs2, y)
       end;

julia> gradient(loss, flat)  # flat gradient, same numbers
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184],)

julia> Zygote.hessian(loss, flat)  # second derivative
5×5 Matrix{Float32}:
  -7.13131   -5.54714  -11.1393   -12.6504    -8.13492
  -5.54714   -7.11092  -11.0208   -13.9231    -9.36316
 -11.1393   -11.0208   -13.7126   -27.9531   -22.741
 -12.6504   -13.9231   -27.9531    18.0875    23.03
  -8.13492   -9.36316  -22.741     23.03      32.0

julia> Flux.destructure(grad)  # acts on non-models, too
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5))
```
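As a minimal sketch of why this flat view is useful, one can run plain gradient descent on the single vector and rebuild the nested model afterwards. This assumes `model` and `x` as defined above; the names `lr` and the loop bounds are illustrative, not part of any API:

```julia
using Flux  # assumes `model` and `x` from the examples above

flat, rebuild = Flux.destructure(model)
loss(v) = sum(abs2, rebuild(v)(x))    # rebuild inside the loss, as shown above

lr = 0.01f0                           # illustrative step size
for _ in 1:100                        # plain gradient descent on one vector
    global flat = flat .- lr .* gradient(loss, flat)[1]
end

model = rebuild(flat)                 # back to the nested structure
```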

Old versions of Flux had an entirely different implementation of `destructure`, which had many bugs (and almost no tests). Many comments online still refer to that now-deleted function, or to memories of it.

### All Parameters

The function `destructure` now lives in `Optimisers.jl`. (Be warned this package is unrelated to the `Flux.Optimisers` sub-module! The confusion is temporary.)

`Optimisers.destructure` — Function

```
destructure(model) -> vector, reconstructor
```

Copies all `trainable`, `isnumeric` parameters in the model to a vector, and also returns a function which reverses this transformation. Differentiable.

**Example**

```
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))
julia> re([3, 5, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))
```

If `model` contains various number types, they are promoted to make `vector`, and are usually restored by `Restructure`. Such restoration follows the rules of `ChainRulesCore.ProjectTo`, and thus will restore floating point precision, but will permit more exotic numbers like `ForwardDiff.Dual`.
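A small sketch of this promotion, inferred from the description above (behaviour hedged, not a doctest):

```julia
using Optimisers

# A Float32 leaf and a Float64 leaf are promoted to one common vector...
v, re = Optimisers.destructure((a = [1f0, 2f0], b = [3.0]))

# ...but reconstruction follows ChainRulesCore.ProjectTo, so the
# original Float32 precision of `a` should be restored:
nt = re(v)
```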

If `model` contains only GPU arrays, then `vector` will also live on the GPU. At present, a mixture of GPU and ordinary CPU arrays is undefined behaviour.

`Optimisers.trainable` — Function

```
trainable(x::Layer) -> NamedTuple
```

This may be overloaded to make optimisers ignore some fields of every `Layer`, which would otherwise contain trainable parameters.

This is very rarely required. Fields of `struct Layer` which contain functions, or integers like sizes, are always ignored anyway. Overloading `trainable` is only necessary when some arrays of numbers are to be optimised, and some arrays of numbers are not.

The default is `Functors.children(x)`, usually a NamedTuple of all fields, and `trainable(x)` must contain a subset of these.
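As a sketch, an overload for a hypothetical layer with one trainable and one frozen array might look like this (`MyLayer` and its field names are invented for illustration):

```julia
using Functors, Optimisers

struct MyLayer
    alpha::Vector{Float32}   # to be optimised
    beta::Vector{Float32}    # e.g. running statistics, to be left alone
end
Functors.@functor MyLayer    # both fields are still visited by Functors

# Report only `alpha` as trainable; optimisers and destructure then skip `beta`.
# The returned NamedTuple is a subset of Functors.children(m), as required:
Optimisers.trainable(m::MyLayer) = (; alpha = m.alpha)
```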

`Optimisers.isnumeric` — Function

```
isnumeric(x) -> Bool
```

Returns `true` on any parameter to be adjusted by Optimisers.jl, namely arrays of non-integer numbers. Returns `false` on all other types.

Requires also that `Functors.isleaf(x) == true`, to focus on e.g. the parent of a transposed matrix, not the wrapper.
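A brief sketch of what this implies, with expected results inferred from the description above (hedged, not run as a doctest):

```julia
using Optimisers

Optimisers.isnumeric([1.0, 2.0])   # arrays of floats are adjusted
Optimisers.isnumeric([1, 2])       # integer arrays are not
Optimisers.isnumeric(sin)          # nor are non-array values
# A wrapper such as transpose(rand(2, 2)) is not a Functors leaf,
# so it should also be excluded; its parent array is the parameter.
```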

### All Layers

Another kind of flat view of a nested model is provided by the `modules` command. This extracts a list of all layers:

`Flux.modules` — Function

```
modules(m)
```

Return an iterator over non-leaf objects that can be reached by recursing `m` over the children given by `functor`.

Useful for applying a function (e.g. a regularizer) over specific modules or subsets of the parameters (e.g. the weights but not the biases).

**Examples**

```
julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu));

julia> m2 = Chain(m1, Dense(64, 10))
Chain(
  Chain(
    Dense(784 => 64),                   # 50_240 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
  ),
  Dense(64 => 10),                      # 650 parameters
)         # Total: 6 trainable arrays, 51_018 parameters,
          # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.

julia> Flux.modules(m2)
7-element Vector{Any}:
 Chain(Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))  # 51_018 parameters, plus 128 non-trainable
 (Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))
 Chain(Dense(784 => 64), BatchNorm(64, relu))  # 50_368 parameters, plus 128 non-trainable
 (Dense(784 => 64), BatchNorm(64, relu))
 Dense(784 => 64)      # 50_240 parameters
 BatchNorm(64, relu)   # 128 parameters, plus 128 non-trainable
 Dense(64 => 10)       # 650 parameters

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)

julia> L2(m2) isa Float32
true
```

### Save and Load

`Flux.state` — Function

```
state(x)
```

Return an object with the same nested structure as `x` according to `Functors.children`, but made only of basic containers (e.g. named tuples, tuples, arrays, and dictionaries).

Besides trainable and non-trainable arrays, the state will contain leaf nodes that are not arrays, such as numbers, symbols, strings, and nothing values. The leaf types that end up in the state could increase in the future.

This method is particularly useful for saving and loading models, since the state contains only simple data types that can be easily serialized.

The state can be passed to `loadmodel!` to restore the model.

**Examples**

**Copy the state into another model**

```
julia> m1 = Chain(Dense(1, 2, tanh; init=ones), Dense(2, 1; init=ones));

julia> s = Flux.state(m1)
(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0], σ = ()), (weight = [1.0 1.0], bias = [0.0], σ = ())),)

julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1; bias=false));  # weights are random numbers

julia> Flux.loadmodel!(m2, s);

julia> m2[1].weight  # now the weights of m2 are the same as m1
2×1 Matrix{Float32}:
 1.0
 1.0

julia> Flux.state(trainmode!(Dropout(0.2)))  # contains p & activity, but not RNG state
(p = 0.2, dims = (), active = true, rng = ())

julia> Flux.state(BatchNorm(1))  # contains non-trainable arrays μ, σ²
(λ = (), β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1)
```

**Save and load with BSON**

```
julia> using BSON
julia> BSON.@save "checkpoint.bson" model_state = s
julia> Flux.loadmodel!(m2, BSON.load("checkpoint.bson")[:model_state])
```

**Save and load with JLD2**

```
julia> using JLD2
julia> JLD2.jldsave("checkpoint.jld2", model_state = s)
julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state"))
```

`Flux.loadmodel!` — Function

```
loadmodel!(dst, src)
```

Copy all the parameters (trainable and non-trainable) from `src` into `dst`.

Recursively walks `dst` and `src` together using `Functors.children`, calling `copyto!` on parameter arrays, or throwing an error when there is a mismatch. Non-array elements (such as activation functions) are not copied and need not match. Zero bias vectors and `bias=false` are considered equivalent (see extended help for more details).

See also `Flux.state`.

**Examples**

```
julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0]))
Chain(
  Dense(5 => 2, tanh),                  # 12 parameters
  Dense(2 => 1),                        # 3 parameters
)                   # Total: 4 arrays, 15 parameters, 316 bytes.

julia> dst[1].weight ≈ ones(2, 5)  # by construction
true

julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false));

julia> Flux.loadmodel!(dst, src);

julia> dst[1].weight ≈ ones(2, 5)  # values changed
false

julia> iszero(dst[2].bias)
true
```

**Extended help**

Throws an error when:

- `dst` and `src` do not share the same fields (at any level)
- the sizes of leaf nodes are mismatched between `dst` and `src`
- copying non-array values to/from an array parameter (except inactive parameters described below)
- `dst` is a "tied" parameter (i.e. refers to another parameter) and loaded into multiple times with mismatched source values

Inactive parameters can be encoded by using the boolean value `false` instead of an array. If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied); however, attempting to copy a non-zero array to an inactive parameter will throw an error. Likewise, copying a `src` value of `false` to any `dst` array is valid, but copying a `src` value of `true` will error.
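A small sketch of these inactive-parameter rules, assuming Flux's default zero-initialised bias (hedged, not a doctest):

```julia
using Flux

dst = Dense(2 => 1; bias = false)   # inactive bias, encoded as `false`
src = Dense(2 => 1)                 # real bias vector, zero-initialised by default

Flux.loadmodel!(dst, src)           # fine: an all-zero bias matches `bias = false`

src.bias .= 1f0                     # now the source bias is non-zero...
# Flux.loadmodel!(dst, src)         # ...so this would throw an error
```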