Flat vs. Nested Structures

A Flux model is a nested structure, with parameters stored within many layers. Sometimes you may want a flat representation of them, to interact with functions expecting just one vector. This is provided by destructure:

julia> model = Chain(Dense(2=>1, tanh), Dense(1=>1))
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters
  Dense(1 => 1),                        # 2 parameters
)                   # Total: 4 arrays, 5 parameters, 276 bytes.

julia> flat, rebuild = Flux.destructure(model)
(Float32[0.863101, 1.2454957, 0.0, -1.6345707, 0.0], Restructure(Chain, ..., 5))

julia> rebuild(zeros(5))  # same structure, new parameters
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters  (all zero)
  Dense(1 => 1),                        # 2 parameters  (all zero)
)                   # Total: 4 arrays, 5 parameters, 276 bytes.
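Here is a quick check of the round trip, sketched for a recent Flux that provides Flux.destructure as used above. Rebuilding from the unmodified vector reproduces the model's outputs. The flat vector is a copy, so overwriting it leaves the original model alone.

```julia
using Flux

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
flat, rebuild = Flux.destructure(model)
x = rand(Float32, 2, 3)

# Rebuilding from the unchanged vector gives a model with the same outputs
@assert rebuild(flat)(x) ≈ model(x)

# `flat` is a fresh copy: overwriting it leaves `model` untouched
w_before = copy(model[1].weight)
flat .= 0
@assert model[1].weight == w_before
@assert iszero(Flux.destructure(rebuild(flat))[1])
```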

Both destructure and the Restructure object it returns (which is callable) can be used within gradient computations. For instance, this computes the Hessian ∂²L/∂θᵢ∂θⱼ of some loss function with respect to all parameters of the Flux model. The resulting matrix has off-diagonal entries, which have no natural place in the nested structure:

julia> x = rand(Float32, 2, 16);

julia> grad = gradient(m -> sum(abs2, m(x)), model)  # nested gradient
((layers = ((weight = Float32[10.339018 11.379145], bias = Float32[22.845667], σ = nothing), (weight = Float32[-29.565302;;], bias = Float32[-37.644184], σ = nothing)),),)

julia> function loss(v::Vector)
         m = rebuild(v)
         y = m(x)
         sum(abs2, y)
       end;

julia> gradient(loss, flat)  # flat gradient, same numbers
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184],)

julia> Zygote.hessian(loss, flat)  # second derivative
5×5 Matrix{Float32}:
  -7.13131   -5.54714  -11.1393  -12.6504   -8.13492
  -5.54714   -7.11092  -11.0208  -13.9231   -9.36316
 -11.1393   -11.0208   -13.7126  -27.9531  -22.741
 -12.6504   -13.9231   -27.9531   18.0875   23.03
  -8.13492   -9.36316  -22.741    23.03     32.0

julia> Flux.destructure(grad)  # acts on non-models, too
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5))
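The flat vector is an ordinary Vector, so you can pass it to code that knows nothing about Flux models. The sketch below uses a hypothetical helper, descend, to stand in for any such vector-only routine; it is not part of Flux. It runs a few plain gradient-descent steps on the flat parameters and then rebuilds a model from the result. The step size and number of steps are arbitrary choices.

```julia
using Flux

# Hypothetical vector-only routine (not part of Flux): plain gradient descent
function descend(f, v::AbstractVector; η = 1f-3, steps = 10)
    for _ in 1:steps
        g = Flux.gradient(f, v)[1]   # gradient with respect to the flat vector
        v = v .- η .* g
    end
    return v
end

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
flat, rebuild = Flux.destructure(model)
x = rand(Float32, 2, 16)
flat_loss(v) = sum(abs2, rebuild(v)(x))   # loss expressed in terms of the flat vector

new_flat = descend(flat_loss, flat)
trained = rebuild(new_flat)               # back to a nested Chain
```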

In order to collect all parameters of a model into a list instead, you can use the trainables function:

julia> Flux.trainables(model)
4-element Vector{AbstractArray}:
 [0.863101 1.2454957]
 [0.0]
 [-1.6345707;;]
 [0.0]

The arrays in the resulting list are the model's own arrays, not copies, so mutating them in place will change the model's parameters.
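The following sketch shows this aliasing, assuming a Flux version that provides Flux.trainables. In-place updates such as p .= 0 reach the model. Rebinding an entry of the list to a new array does not.

```julia
using Flux

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
ps = Flux.trainables(model)

# The list holds the model's own arrays, not copies
@assert ps[1] === model[1].weight

# In-place broadcasting writes straight into the model...
foreach(p -> p .= 0, ps)
@assert all(iszero, model[2].weight)

# ...whereas replacing a list entry only changes the list
ps[1] = ones(Float32, 1, 2)
@assert all(iszero, model[1].weight)
```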

All Parameters

The functions destructure and trainables live in Optimisers.jl.

Optimisers.destructure (Function)
destructure(model) -> vector, reconstructor

Copies all trainable, isnumeric parameters in the model to a vector, and also returns a function which reverses this transformation. Differentiable.

Example

julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))

julia> re([3, 5, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))

If model contains several number types, they are promoted to a common type to make vector, and the original types are usually restored by Restructure. This restoration follows the rules of ChainRulesCore.ProjectTo: it will restore floating point precision, but will permit more exotic numbers like ForwardDiff.Dual.
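For example, suppose Optimisers.jl is installed. A structure mixing Float32 and Float64 arrays is then flattened to a Float64 vector. Restructure converts each piece back to its original element type:

```julia
using Optimisers

m = (a = Float32[1, 2], b = [3.0])      # Float32 and Float64 leaves
v, re = Optimisers.destructure(m)
@assert eltype(v) == Float64            # promoted to a common type

m2 = re([10.0, 20.0, 30.0])
@assert eltype(m2.a) == Float32         # original precision restored
@assert m2.a == Float32[10, 20]
@assert m2.b == [30.0]
```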

If model contains only GPU arrays, then vector will also live on the GPU. At present, a mixture of GPU and ordinary CPU arrays is undefined behaviour.

Optimisers.trainable (Function)
trainable(x::Layer) -> NamedTuple

This may be overloaded to make optimisers ignore some fields of every Layer, which would otherwise contain trainable parameters.

Warning

This is very rarely required. Fields of struct Layer which contain functions, or integers like sizes, are always ignored anyway. Overloading trainable is only necessary when some arrays of numbers are to be optimised, and some arrays of numbers are not.

The default is Functors.children(x), usually a NamedTuple of all fields, and trainable(x) must contain a subset of these.
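The sketch below defines a hypothetical layer, ScaleShift, and uses trainable to exclude its shift field. It assumes Optimisers.jl and Functors.jl are installed. destructure then skips that field, and the reconstructor carries it over unchanged:

```julia
using Optimisers, Functors

struct ScaleShift   # hypothetical layer, for illustration only
    scale
    shift
end
Functors.@functor ScaleShift
Optimisers.trainable(s::ScaleShift) = (; scale = s.scale)   # only `scale` is trainable

s = ScaleShift([1.0, 2.0], [3.0])
v, re = Optimisers.destructure(s)
@assert v == [1.0, 2.0]                # `shift` is not collected

s2 = re([10.0, 20.0])
@assert s2.scale == [10.0, 20.0]
@assert s2.shift == [3.0]              # non-trainable field carried over
```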

Optimisers.trainables (Function)
trainables(x; path = false)

Return an iterable over all the trainable parameters in x, that is all the numerical arrays (see isnumeric) which are reachable through trainable.

Parameters appearing multiple times in the model (tied weights) will be present only once in the output.

If path = false, the output is a list of numerical arrays.

If path = true, the output is a list of (KeyPath, AbstractArray) pairs, where KeyPath is a type representing the path to the array in the original structure.

See also destructure for a similar operation that returns a single flat vector instead.
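Here is a small sketch of the tied-weight rule, assuming Optimisers.jl is installed. A shared array is listed once, and the entry is the array itself rather than a copy. destructure likewise stores it only once.

```julia
using Optimisers

w = [1.0, 2.0]
m = (enc = w, dec = w, bias = [3.0])   # `w` is shared ("tied")

ps = Optimisers.trainables(m)
@assert length(ps) == 2                # w appears only once
@assert ps[1] === w                    # the same array, not a copy

flat, _ = Optimisers.destructure(m)
@assert length(flat) == 3              # 2 entries from w, 1 from bias
```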

Examples

julia> struct MyLayer
         w
         b
       end

julia> Functors.@functor MyLayer

julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,) # only w is trainable in this example

julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]);

julia> trainables(x)
1-element Vector{AbstractArray}:
 [1.0, 2.0, 3.0]

julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);

julia> trainables(x) # collects nested parameters
2-element Vector{AbstractArray}:
 [1.0, 2.0]
 [3.0]

julia> x = (a = [1.0,2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0,7.0]));

julia> for (kp, y) in trainables(x, path = true)
           println(kp, " => ", y)
       end
KeyPath(:a,) => [1.0, 2.0]
KeyPath(:b, 1, "c") => [3.0, 4.0]
KeyPath(:b, 2) => [6.0, 7.0]

julia> getkeypath(x, KeyPath(:b, 1, "c"))
2-element Vector{Float64}:
 3.0
 4.0
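The path output and the value output fit together: following each KeyPath with getkeypath leads back to the very same array. This sketch assumes Optimisers.jl and Functors.jl are both installed:

```julia
using Optimisers, Functors

x = (a = [1.0, 2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0, 7.0]))
for (kp, p) in Optimisers.trainables(x; path = true)
    # each path leads back to the identical array object
    @assert Functors.getkeypath(x, kp) === p
end
```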
Optimisers.isnumeric (Function)
isnumeric(x) -> Bool

Returns true on any parameter to be adjusted by Optimisers.jl, namely arrays of non-integer numbers. Returns false on all other types.

Requires also that Functors.isleaf(x) == true, to focus on e.g. the parent of a transposed matrix, not the wrapper.
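A few illustrative calls, assuming Optimisers.jl is installed. The transposed case relies on Functors treating Transpose as a wrapper whose child is its parent array, as described above:

```julia
using Optimisers, LinearAlgebra

@assert Optimisers.isnumeric([1.0, 2.0])              # floating-point array: a parameter
@assert !Optimisers.isnumeric([1, 2])                  # integer array: never trained
@assert !Optimisers.isnumeric(3.0)                     # a scalar, not an array
@assert !Optimisers.isnumeric(transpose([1.0 2.0]))    # the wrapper is not a leaf...
@assert Optimisers.isnumeric(parent(transpose([1.0 2.0])))  # ...but its parent is
```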

Flux.params (Function)
params(model)

Returns a Zygote.Params object containing all parameter arrays from the model. This is deprecated! This function was the cornerstone of how Flux used Zygote's implicit mode gradients, but since Flux 0.13 we use explicit mode gradient(m -> loss(m, x, y), model) instead. To collect all the parameter arrays for other purposes, use Flux.trainables(model).
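If you are migrating from params, a minimal explicit-mode training step might look like the sketch below. The model, data, the loss function name mse_loss and the Descent learning rate are arbitrary illustrative choices. Flux.setup and Flux.update! take the explicit gradient with respect to the model, so no Params object is involved.

```julia
using Flux

model = Dense(2 => 1)
x = rand(Float32, 2, 8)
y = rand(Float32, 1, 8)
mse_loss(m, x, y) = Flux.mse(m(x), y)    # illustrative loss

opt_state = Flux.setup(Flux.Descent(0.1), model)
l0 = mse_loss(model, x, y)

# Explicit mode: differentiate with respect to the model itself
grads = Flux.gradient(m -> mse_loss(m, x, y), model)
Flux.update!(opt_state, model, grads[1])  # updates `model` in place
```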


All Layers

Another kind of flat view of a nested model is provided by the modules function. This extracts a list of all layers, together with the intermediate containers (such as tuples) that hold them:

Flux.modules (Function)
modules(m)

Return an iterator over non-leaf objects that can be reached by recursing m over the children given by Functors.functor.

Useful for applying a function (e.g. a regularizer) over specific modules or subsets of the parameters (e.g. the weights but not the biases).

Examples

julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu));

julia> m2 = Chain(m1, Dense(64, 10))
Chain(
  Chain(
    Dense(784 => 64),                   # 50_240 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
  ),
  Dense(64 => 10),                      # 650 parameters
)         # Total: 6 trainable arrays, 51_018 parameters,
          # plus 2 non-trainable, 128 parameters, summarysize 200.211 KiB.

julia> Flux.modules(m2)
7-element Vector{Any}:
 Chain(Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))  # 51_018 parameters, plus 128 non-trainable
 (Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))
 Chain(Dense(784 => 64), BatchNorm(64, relu))  # 50_368 parameters, plus 128 non-trainable
 (Dense(784 => 64), BatchNorm(64, relu))
 Dense(784 => 64)    # 50_240 parameters
 BatchNorm(64, relu)  # 128 parameters, plus 128 non-trainable
 Dense(64 => 10)     # 650 parameters

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)

julia> L2(m2) isa Float32
true

Save and Load

Flux.state (Function)
state(x)

Return an object with the same nested structure as x according to Functors.children, but made only of basic containers (e.g. named tuples, tuples, arrays, and dictionaries).

Besides trainable and non-trainable arrays, the state will contain leaf nodes that are not arrays, such as numbers, symbols, strings, and nothing values. The leaf types that end up in the state could increase in the future.

This method is particularly useful for saving and loading models, since the state contains only simple data types that can be easily serialized.

The state can be passed to loadmodel! to restore the model.
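BSON and JLD2 appear in the examples below. Julia's standard-library Serialization module can also store the state, in memory or on disk, as in the sketch that follows. Julia's serialization format is not guaranteed to be readable across Julia versions, so this approach suits short-lived checkpoints better. The model architecture here is an arbitrary example.

```julia
using Flux, Serialization

model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
s = Flux.state(model)                 # plain NamedTuples, Tuples and arrays

io = IOBuffer()
serialize(io, s)                      # write the state (a file would work too)
seekstart(io)
s2 = deserialize(io)                  # read it back

model2 = Chain(Dense(2 => 3, relu), Dense(3 => 1))   # same architecture, fresh weights
Flux.loadmodel!(model2, s2)

xin = rand(Float32, 2, 4)
@assert model2(xin) ≈ model(xin)
```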

Examples

Copy the state into another model

julia> m1 = Chain(Dense(1, 2, tanh; init=ones), Dense(2, 1; init=ones));

julia> s = Flux.state(m1)
(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0], σ = ()), (weight = [1.0 1.0], bias = [0.0], σ = ())),)

julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1; bias=false));  # weights are random numbers

julia> Flux.loadmodel!(m2, s);

julia> m2[1].weight   # now the weights of m2 are the same as m1
2×1 Matrix{Float32}:
 1.0
 1.0

julia> Flux.state(trainmode!(Dropout(0.2)))  # contains p & activity, but not RNG state
(p = 0.2, dims = (), active = true, rng = ())

julia> Flux.state(BatchNorm(1))  # contains non-trainable arrays μ, σ²
(λ = (), β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1)

Save and load with BSON

julia> using BSON

julia> BSON.@save "checkpoint.bson" model_state = s

julia> Flux.loadmodel!(m2, BSON.load("checkpoint.bson")[:model_state])

Save and load with JLD2

julia> using JLD2

julia> JLD2.jldsave("checkpoint.jld2", model_state = s)

julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state"))
Flux.loadmodel! (Function)
loadmodel!(dst, src)

Copy all the parameters (trainable and non-trainable) from src into dst.

Recursively walks dst and src together using Functors.children, calling copyto! on parameter arrays and throwing an error when there is a mismatch. Non-array elements (such as activation functions) are not copied and need not match. Zero bias vectors and bias=false are considered equivalent (see extended help for more details).

See also Flux.state.

Examples

julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0]))
Chain(
  Dense(5 => 2, tanh),                  # 12 parameters
  Dense(2 => 1),                        # 3 parameters
)                   # Total: 4 arrays, 15 parameters, 316 bytes.

julia> dst[1].weight ≈ ones(2, 5)  # by construction
true

julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false));

julia> Flux.loadmodel!(dst, src);

julia> dst[1].weight ≈ ones(2, 5)  # values changed
false

julia> iszero(dst[2].bias)
true

Extended help

Throws an error when:

  • dst and src do not share the same fields (at any level)
  • the sizes of leaf nodes are mismatched between dst and src
  • copying non-array values to/from an array parameter (except inactive parameters described below)
  • dst is a "tied" parameter (i.e. refers to another parameter) and loaded into multiple times with mismatched source values
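The size-mismatch case can be sketched as follows, assuming a recent Flux. Loading a Dense(2 => 2) into a Dense(3 => 2) fails because the two weight matrices have different shapes.

```julia
using Flux

dst = Dense(3 => 2)     # weight is 2×3
src = Dense(2 => 2)     # weight is 2×2
err = try
    Flux.loadmodel!(dst, src)
    nothing             # reached only if loading succeeded
catch e
    e                   # keep the error for inspection
end
@assert err !== nothing
```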

Inactive parameters can be encoded by using the boolean value false instead of an array. If dst == false and src is an all-zero array, no error will be raised (and no values copied); however, attempting to copy a non-zero array to an inactive parameter will throw an error. Likewise, copying a src value of false to any dst array is valid, but copying a src value of true will error.
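The next sketch shows the inactive-parameter rules. It assumes a recent Flux in which Dense(...; bias = false) stores the bias as false. An all-zero bias may be loaded into it, but a non-zero bias raises an error.

```julia
using Flux

dst = Dense(2 => 1; bias = false)    # inactive bias, stored as `false`
src = Dense(2 => 1)                  # default bias is an all-zero array
Flux.loadmodel!(dst, src)            # allowed: zero array into `false`
@assert dst.weight == src.weight
@assert dst.bias === false

src.bias .= 1                        # make the source bias non-zero
err = try
    Flux.loadmodel!(dst, src)
    nothing
catch e
    e
end
@assert err !== nothing
```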


KeyPath

Functors.KeyPath (Type)
KeyPath(keys...)

A type for representing a path of keys to a value in a nested structure. Can be constructed with a sequence of keys, or by concatenating other KeyPaths. Keys can be of type Symbol, String, Int, or CartesianIndex.

For custom types, access through symbol keys is assumed to be done with getproperty. For consistency, the method Base.propertynames is used to get the viable property names.

For string, integer, and cartesian index keys, the access is done with getindex instead.
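The sketch below shows both access styles on an ordinary nested NamedTuple, assuming Functors.jl is installed. Names are qualified so nothing depends on what Functors exports. Symbol keys go through property access and Int keys through indexing. An existing KeyPath can also be extended with further keys.

```julia
using Functors

x = (layers = ([1.0 2.0], (w = [3.0, 4.0],)),)
kp = Functors.KeyPath(:layers, 2, :w)     # Symbol -> property, Int -> index
@assert Functors.getkeypath(x, kp) == [3.0, 4.0]

kp1 = Functors.KeyPath(kp, 1)             # extend an existing path by one key
@assert Functors.getkeypath(x, kp1) == 3.0
```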

See also getkeypath, haskeypath.

Examples

julia> kp = KeyPath(:b, 3)
KeyPath(:b, 3)

julia> KeyPath(:a, kp, :c, 4) # construct mixing keys and keypaths
KeyPath(:a, :b, 3, :c, 4)

julia> struct T
           a
           b
       end

julia> function Base.getproperty(x::T, k::Symbol)
           if k in fieldnames(T)
               return getfield(x, k)
           elseif k === :ab
               return "ab"
           else
               error("type T has no property $k")
           end
       end;

julia> Base.propertynames(::T) = (:a, :b, :ab);

julia> x = T(3, Dict(:c => 4, :d => 5));

julia> getkeypath(x, KeyPath(:ab)) # equivalent to x.ab
"ab"

julia> getkeypath(x, KeyPath(:b, :c)) # equivalent to (x.b)[:c]
4
Functors.getkeypath (Function)
getkeypath(x, kp::KeyPath)

Return the value in x at the path kp.

See also KeyPath, haskeypath, and setkeypath!.

Examples

julia> x = Dict(:a => 3, :b => Dict(:c => 4, "d" => [5, 6, 7]))
Dict{Symbol, Any} with 2 entries:
  :a => 3
  :b => Dict{Any, Any}(:c=>4, "d"=>[5, 6, 7])

julia> getkeypath(x, KeyPath(:b, "d", 2))
6
Functors.haskeypath (Function)
haskeypath(x, kp::KeyPath)

Return true if x has a value at the path kp.

See also KeyPath, getkeypath, and setkeypath!.

Examples

julia> x = Dict(:a => 3, :b => Dict(:c => 4, "d" => [5, 6, 7]))
Dict{Symbol, Any} with 2 entries:
  :a => 3
  :b => Dict{Any, Any}(:c=>4, "d"=>[5, 6, 7])

julia> haskeypath(x, KeyPath(:a))
true

julia> haskeypath(x, KeyPath(:b, "d", 1))
true

julia> haskeypath(x, KeyPath(:b, "d", 4))
false
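haskeypath and getkeypath combine naturally into a lookup with a fallback. The helper safe_getkeypath below is hypothetical and not part of Functors.jl:

```julia
using Functors

# Hypothetical helper: look up a path, returning `default` when it is absent
safe_getkeypath(x, kp, default) =
    Functors.haskeypath(x, kp) ? Functors.getkeypath(x, kp) : default

x = Dict(:a => 3, :b => Dict(:c => 4, "d" => [5, 6, 7]))
@assert safe_getkeypath(x, Functors.KeyPath(:b, "d", 2), 0) == 6
@assert safe_getkeypath(x, Functors.KeyPath(:b, "d", 4), 0) == 0   # index out of range
```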