Recursive transformations from Functors.jl

Flux models are deeply nested structures, and Functors.jl provides the tools needed to explore such objects, apply functions to the parameters they contain (e.g. for moving them to the GPU), and rebuild them.

Flux ≤ v0.14

All layers were previously defined with the Functors.@functor macro. This still works, but it is recommended that you use the new Flux.@layer macro instead. Both allow Flux.setup to see the parameters inside, and gpu to move them to the GPU, but Flux.@layer also overloads printing, and offers a way to define trainable at the same time.

Functors v0.5

With Functors.jl v0.5, which is required by Flux v0.15 and later, every custom type is a functor by default. This means that applying Flux.@layer to a type is no longer strictly necessary, but it is still recommended for additional features like pretty-printing.
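As a minimal sketch of what "functor by default" means, assuming Functors.jl v0.5 or later is installed: the hypothetical struct `Affine` below has no `@functor` or `Flux.@layer` annotation, yet `fmap` still recurses into its fields and rebuilds it.

```julia
using Functors  # assumes Functors.jl v0.5 or later

# Hypothetical struct with no @functor or Flux.@layer annotation.
struct Affine
    W
    b
end

m = Affine([1.0 2.0; 3.0 4.0], [0.5, 0.5])

# fmap visits the leaf arrays W and b, converts each to Float32,
# and reconstructs an Affine from the results via its default constructor.
m32 = fmap(x -> Float32.(x), m)
```

Adding `Flux.@layer Affine` on top of this would not change the traversal; it only adds the printing and `trainable` conveniences described below.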

Functors.jl's own documentation has notes on basic usage with more details. Additionally, the Advanced Model Building and Customisation page covers the use cases of Functors in greater detail.

Flux.@layer (macro)
@layer [showtype] MyModel [trainable=(field1,...)]

This macro adds convenience functionality to a custom type to serve as a neural network layer, as a module, or as an entire model.

The optional keyword trainable allows you to specify which fields of your model can be trained, instead of assuming that all fieldnames(MyModel) are trainable. Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes. The same restriction can also be achieved by defining a trainable(::MyModel) method for your type.
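A minimal sketch of the method-based alternative, assuming Flux v0.15 or later: `Scaled` is a hypothetical layer whose `scale` array should stay fixed during training. Defining `Flux.trainable` to return only `layer` keeps `scale` out of the optimiser's updates, while functions such as `gpu` or `f32` still see both fields, since they traverse all children.

```julia
using Flux  # assumes Flux v0.15 or later

# Hypothetical layer: a sub-layer followed by a fixed elementwise scale.
struct Scaled
    layer
    scale
end

Flux.@layer Scaled  # no trainable keyword here; the method below plays that role

(m::Scaled)(x) = m.scale .* m.layer(x)

# Only `layer` is trainable; `scale` is excluded from optimiser updates.
Flux.trainable(m::Scaled) = (; layer = m.layer)

m = Scaled(Dense(2 => 2), fill(2f0, 2))
opt_state = Flux.setup(Adam(), m)
```

Use either this method or the trainable=(...) keyword of the macro, not both.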

The macro also overloads the 3-arg show(::IO, ::MIME"text/plain", ::MyModel) for pretty printing. The optional argument showtype can take any of the following values:

  • :expand (default): This will expand the representation of container types like Chain, while maintaining a compact representation of types like Dense containing only arrays.
  • :noexpand: This is to be used in case your type contains other layers but you want to keep the representation simple.
  • :ignore: To opt out of the pretty printing.

You probably still want to define the 2-arg show(::IO, ::MyModel); the macro does not touch this.
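For instance, a sketch assuming Flux v0.15 or later, with a hypothetical two-layer container `Duo`: `Flux.@layer` provides the multi-line REPL display, while a hand-written 2-arg method controls the one-line form used by `print`, `string`, and string interpolation.

```julia
using Flux  # assumes Flux v0.15 or later

# Hypothetical container layer holding two sub-layers applied in sequence.
struct Duo
    a
    b
end

Flux.@layer Duo  # overloads the 3-arg show used by the REPL

(d::Duo)(x) = d.b(d.a(x))

# 2-arg show: the compact one-line form, which @layer leaves alone.
Base.show(io::IO, d::Duo) = print(io, "Duo(", d.a, " |> ", d.b, ")")

d = Duo(Dense(2 => 3), Dense(3 => 1))
println(d)  # Duo(Dense(2 => 3) |> Dense(3 => 1))
```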

Note that re-running the macro with different options may not remove all previously defined methods; you will need to restart Julia.

Example

julia> struct Trio; a; b; c end

julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4))
Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))

julia> Flux.@layer Trio

julia> tri  # now the layer is printed like Chain
Trio(
  Dense(2 => 1, tanh),                  # 3 parameters
  Dense(1 => 1; bias=false),            # 1 parameters
  Dropout(0.4),
)                   # Total: 3 arrays, 4 parameters, 240 bytes.

julia> Flux.@layer :noexpand Trio trainable=(a,b)

julia> tri  # now the layer is printed compactly
Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))  # 4 parameters

julia> opt_state = Flux.setup(Adam(), tri); # `c` is not in the optimizer state

The macro also adds methods to make using Flux with Enzyme easier.

  • Duplicated(m::Layer) allocates a copy for the gradient (initially zero).
  • This is made callable: (m::Duplicated{<:Layer})(x...) = m.val(x...)
  • Pretty printing for show(io, mime, ::Duplicated{<:Layer})
Functors.@functor (macro)
@functor T
@functor T (x,)

Adds methods to functor allowing recursion into objects of type T, and reconstruction. Assumes that T has a constructor accepting all of its fields, which is true unless you have provided an inner constructor which does not.

By default all fields of T are considered children; this can be restricted by providing a tuple of field names.

Examples

julia> struct Foo; x; y; end

julia> Functors.children(Foo(1,2))
(x = 1, y = 2)

julia> _, re = Functors.functor(Foo(1,2));

julia> re((10, 20))
Foo(10, 20)

julia> @functor Foo # same as before, nothing changes

julia> struct TwoThirds a; b; c; end

julia> @functor TwoThirds (a, c)

julia> ch2, re3 = Functors.functor(TwoThirds(10,20,30));

julia> ch2
(a = 10, c = 30)

julia> re3(("ten", "thirty"))
TwoThirds("ten", 20, "thirty")

julia> fmap(x -> 10x, TwoThirds(Foo(1,2), Foo(3,4), 56))
TwoThirds(Foo(10, 20), Foo(3, 4), 560)
Functors.fmap (function)
fmap(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalk(), [prune])

A structure and type preserving map.

By default it transforms every leaf node (identified by exclude, default isleaf) by applying f, and otherwise traverses x recursively using functor. Optionally, it may also be associated with objects ys with the same tree structure. In that case, f is applied to the corresponding leaf nodes in x and ys.
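A small sketch of the ys... form, assuming Functors.jl v0.5: two NamedTuples with the same structure (hypothetical `params` and `grads`) are combined leaf by leaf, as in a hand-written gradient-descent step. Real Flux training would use Flux.setup and Flux.update! instead; this only illustrates how f receives matching leaves.

```julia
using Functors  # assumes Functors.jl v0.5 or later

# Two trees with identical structure (hypothetical parameters and gradients).
params = (W = [1.0 2.0; 3.0 4.0], b = [0.0, 0.0])
grads  = (W = [0.5 0.5; 0.5 0.5], b = [1.0, -1.0])

lr = 0.1  # hypothetical learning rate

# f receives the corresponding leaves of both trees: p from params, g from grads.
stepped = fmap((p, g) -> p .- lr .* g, params, grads)
```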

See also fmap_with_path and fmapstructure.

Examples

julia> fmap(string, (x=1, y=(2, 3)))
(x = "1", y = ("2", "3"))

julia> nt = (a = [1,2], b = [23, (45,), (x=6//7, y=())], c = [8,9]);

julia> fmap(println, nt)
[1, 2]
23
45
6//7
()
[8, 9]
(a = nothing, b = Any[nothing, (nothing,), (x = nothing, y = nothing)], c = nothing)

julia> fmap(println, nt; exclude = x -> x isa Array)
[1, 2]
Any[23, (45,), (x = 6//7, y = ())]
[8, 9]
(a = nothing, b = nothing, c = nothing)

julia> twice = [1, 2];  # println only acts once on this

julia> fmap(println, (i = twice, ii = 34, iii = [5, 6], iv = (twice, 34), v = 34.0))
[1, 2]
34
[5, 6]
34
34.0
(i = nothing, ii = nothing, iii = nothing, iv = (nothing, nothing), v = nothing)

julia> d1 = Dict("x" => [1,2], "y" => 3);

julia> d2 = Dict("x" => [4,5], "y" => 6, "z" => "an_extra_value");

julia> fmap(+, d1, d2) == Dict("x" => [5, 7], "y" => 9) # Note that "z" is ignored
true

Mutable objects which appear more than once are only handled once (by caching f(x) in an IdDict). Thus the relationship x.i === x.iv[1] will be preserved. An immutable object which appears twice is not stored in the cache, thus f(34) will be called twice, and the results will agree only if f is pure.
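A minimal sketch of this caching, assuming Functors.jl v0.5: the hypothetical array `shared` appears twice in the input, and the identity relationship between its two occurrences survives into the output.

```julia
using Functors  # assumes Functors.jl v0.5 or later

shared = [1.0, 2.0]
tree = (a = shared, b = (shared, 3.0))

out = fmap(x -> 2 .* x, tree)

# The second occurrence of `shared` is served from the cache, so both
# positions in the output refer to the very same new array.
out.a === out.b[1]  # true
```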

By default, almost all container-like types have children to recurse into. Arrays of numbers do not.

To opt out of recursion for custom types use @leaf or pass a custom exclude function.
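A minimal sketch of the @leaf route, assuming Functors.jl v0.5, with a hypothetical `Interval` type that should be handled as one unit. Without the `@leaf` line, fmap would recurse into `lo` and `hi` separately and this call would return `Interval(0, 10)`.

```julia
using Functors  # assumes Functors.jl v0.5 or later

# Hypothetical struct that should be treated as a single leaf.
struct Interval
    lo
    hi
end

Functors.@leaf Interval  # fmap will now pass whole Intervals to f

tree = (x = [1, 2], r = Interval(0, 1))

# f must handle both kinds of leaves: plain arrays and Intervals.
out = fmap(v -> v isa Interval ? Interval(v.lo, 2 * v.hi) : 10 .* v, tree)
```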

julia> struct Foo; x; y; end

julia> struct Bar; x; end

julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7))));

julia> fmap(x -> 10x, m)
Foo(Bar([10, 20, 30]), (40, 50, Bar(Foo(60, 70))))

julia> fmap(string, m)
Foo(Bar("[1, 2, 3]"), ("4", "5", Bar(Foo("6", "7"))))

julia> fmap(string, m, exclude = v -> v isa Bar)
Foo("Bar([1, 2, 3])", (4, 5, "Bar(Foo(6, 7))"))

To recurse into custom types without reconstructing them afterwards, use fmapstructure.

For advanced customization of the traversal behaviour, pass a custom walk function that subtypes Functors.AbstractWalk. The call fmap(f, x, ys...; walk = mywalk) will wrap mywalk in ExcludeWalk then CachedWalk. Here, ExcludeWalk is responsible for applying f at excluded nodes. For a low-level interface for executing a user-constructed walk, see execute.

julia> struct MyWalk <: Functors.AbstractWalk end

julia> (::MyWalk)(recurse, x) = x isa Bar ? "hello" :
                                            Functors.DefaultWalk()(recurse, x)

julia> fmap(x -> 10x, m; walk = MyWalk())
Foo("hello", (40, 50, "hello"))

The behaviour when the same node appears twice can be altered by giving a value to the prune keyword, which is then used in place of all but the first:

julia> twice = [1, 2];

julia> fmap(float, (x = twice, y = [1,2], z = twice); prune = missing)
(x = [1.0, 2.0], y = [1.0, 2.0], z = missing)
Functors.fmap_with_path (function)
fmap_with_path(f, x, ys...; exclude = isleaf, walk = DefaultWalkWithPath(), [prune])

Like fmap, but also passes a KeyPath to f for each node in the recursion. The KeyPath is a tuple of the indices used to reach the current node from the root of the recursion. The KeyPath is constructed by the walk function, and can be used to reconstruct the path to the current node from the root of the recursion.

f has to accept two arguments: the associated KeyPath and the value of the current node.

exclude also receives the KeyPath as its first argument and a node as its second. It should return true if the recursion should not continue into that node's children, in which case f is applied to the node itself.

prune is used to control the behaviour when the same node appears twice, see fmap for more information.
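A further sketch, assuming Functors.jl v0.5, of using the KeyPath to target a single leaf: only the weight matrix of the hypothetical `layer1` entry is scaled, and every other leaf is returned unchanged.

```julia
using Functors  # assumes Functors.jl v0.5 or later

# Hypothetical nested parameters.
model = (layer1 = (W = [1.0 2.0], b = [0.0]), layer2 = (W = [3.0 4.0], b = [1.0]))

# f receives (keypath, leaf); compare the path against an explicit KeyPath.
out = fmap_with_path(model) do kp, x
    kp == KeyPath(:layer1, :W) ? 10 .* x : x
end
```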

Examples

julia> x = ([1, 2, 3], 4, (a=5, b=Dict("A"=>6, "B"=>7), c=Dict("C"=>8, "D"=>9)));

julia> exclude(kp, x) = kp == KeyPath(3, :c) || Functors.isleaf(x);

julia> fmap_with_path((kp, x) -> x isa Dict ? nothing : x.^2, x; exclude = exclude)
([1, 4, 9], 16, (a = 25, b = Dict("B" => 49, "A" => 36), c = nothing))
Functors.isleaf (function)
isleaf(x)

Return true if x has no children according to functor.

Examples

julia> Functors.isleaf(1)
true

julia> Functors.isleaf([2, 3, 4])
true

julia> Functors.isleaf(["five", [6, 7]])
false

julia> Functors.isleaf([])
false

julia> Functors.isleaf((8, 9))
false

julia> Functors.isleaf(())
true
Functors.fcollect (function)
fcollect(x; exclude = v -> false)

Traverse x by recursing into each child of x as defined by functor, and collect the results into a flat array, ordered by a depth-first (pre-order) traversal of x that respects the iteration order of children calls.

Doesn't recurse inside branches rooted at nodes v for which exclude(v) == true. In such cases, the root v is also excluded from the result. By default, exclude always yields false.
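A small sketch, assuming Functors.jl v0.5, of a common use: fcollect returns every node, containers included, so filtering the result gives just the numeric arrays of a (hypothetical) nested model.

```julia
using Functors  # assumes Functors.jl v0.5 or later

m = (enc = (W = [1 2; 3 4], b = [0, 0]), dec = (W = [5 6],))

# fcollect includes the NamedTuple containers too; keep only numeric arrays.
arrays = filter(x -> x isa AbstractArray{<:Number}, fcollect(m))
```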

See also children.

Examples

julia> struct Foo; x; y; end

julia> struct Bar; x; end

julia> struct TypeWithNoChildren; x; y; end

julia> @leaf TypeWithNoChildren

julia> m = Foo(Bar([1,2,3]), TypeWithNoChildren(:a, :b))
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))

julia> fcollect(m)
4-element Vector{Any}:
 Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
 Bar([1, 2, 3])
 [1, 2, 3]
 TypeWithNoChildren(:a, :b)

julia> fcollect(m, exclude = v -> v isa Bar)
2-element Vector{Any}:
 Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
 TypeWithNoChildren(:a, :b)

julia> fcollect(m, exclude = v -> Functors.isleaf(v))
2-element Vector{Any}:
 Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
 Bar([1, 2, 3])
Functors.functor (function)
functor(x)
functor(typeof(x), x)

Returns a tuple containing, first, a NamedTuple of the children of x (typically its fields), and second, a reconstruction function. This controls the behaviour of fmap.

Methods should be added to functor(::Type{T}, x) for custom types, usually using the macro @functor.
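@functor and the v0.5 default both assume that T can be rebuilt from all of its fields. A minimal sketch of adding a method by hand when that is not true, assuming Functors.jl v0.5: the hypothetical `Window` has an inner constructor taking only `data`, so the default would try to call `Window(data, len)`, which does not exist. The method below exposes `data` as the only child and rebuilds through the one-argument constructor.

```julia
using Functors  # assumes Functors.jl v0.5 or later

# Hypothetical type whose only constructor takes `data`; `len` is derived from it.
struct Window
    data::Vector{Float64}
    len::Int
    Window(data) = new(data, length(data))
end

# Children: only `data`. Reconstruction: the one-argument constructor,
# which recomputes `len` from the new data.
Functors.functor(::Type{<:Window}, w) = ((; data = w.data), nt -> Window(nt.data))

w2 = fmap(x -> 2 .* x, Window([1.0, 2.0]))
```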

Functors.fmapstructure (function)
fmapstructure(f, x, ys...; exclude = isleaf, [prune])

Like fmap, but doesn't preserve the type of custom structs. Instead, it returns a NamedTuple (or a Tuple, or an array), or a nested set of these.

Useful when the output must not contain custom structs.
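For example, a sketch assuming Functors.jl v0.5: summarising the array shapes of a hypothetical layer type, where the result is a plain NamedTuple rather than another `TinyLayer`.

```julia
using Functors  # assumes Functors.jl v0.5 or later

# Hypothetical layer type.
struct TinyLayer
    W
    b
end

m = TinyLayer(ones(3, 2), zeros(3))

# Replace every leaf array by its size; the TinyLayer wrapper becomes a NamedTuple.
shapes = fmapstructure(size, m)  # (W = (3, 2), b = (3,))
```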

See also fmap and fmapstructure_with_path.

Examples

julia> struct Foo; x; y; end

julia> m = Foo([1,2,3], [4, (5, 6), Foo(7, 8)]);

julia> fmapstructure(x -> 2x, m)
(x = [2, 4, 6], y = Any[8, (10, 12), (x = 14, y = 16)])

julia> fmapstructure(println, m)
[1, 2, 3]
4
5
6
7
8
(x = nothing, y = Any[nothing, (nothing, nothing), (x = nothing, y = nothing)])
Functors.execute (function)
execute(walk, x, ys...)

Execute a walk that recursively calls itself, starting at a node x in a Functors tree, as well as optional associated nodes ys... in other Functors trees. Any custom walk function that subtypes Functors.AbstractWalk is permitted.
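A minimal sketch, assuming Functors.jl v0.5, of running a custom walk directly with execute. The hypothetical `SumWalk` adds up every number in a tree: at leaves it returns a value, and elsewhere it calls `recurse` on each child (`recurse` is supplied by execute and re-enters the same walk).

```julia
using Functors  # assumes Functors.jl v0.5 or later

# Hypothetical walk that adds up every number found in a tree.
struct SumWalk <: Functors.AbstractWalk end

function (::SumWalk)(recurse, x)
    if Functors.isleaf(x)
        # numbers and numeric arrays contribute their (summed) value; other leaves contribute 0
        return x isa Number ? x : x isa AbstractArray{<:Number} ? sum(x) : 0
    end
    # not a leaf: recurse into each child and add up the results
    return sum(recurse, Functors.children(x); init = 0)
end

total = Functors.execute(SumWalk(), (a = [1, 2], b = (3, (c = 4,))))
```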

Functors.AbstractWalk (type)
AbstractWalk

Any walk for use with fmap should inherit from this type. A walk subtyping AbstractWalk must satisfy the walk function interface:

struct MyWalk <: AbstractWalk end

function (::MyWalk)(recurse, x, ys...)
  # implement this
end

The walk function is called on a node x in a Functors tree. It may also be passed associated nodes ys... in other Functors trees. The walk function recurses further into (x, ys...) by calling recurse on the child nodes. The choice of which nodes to recurse and in what order is custom to the walk.

Functors.ExcludeWalk (type)
ExcludeWalk(walk, fn, exclude)

A walk that recurses into nodes (x, ys...) according to walk, except when exclude(x) is true, in which case fn(x, ys...) is applied instead of recursing further.

Typically wraps an existing walk for use with fmap.

Functors.CachedWalk (type)
CachedWalk(walk[; prune])

A walk that recurses into nodes (x, ys...) according to walk, storing the output of the recursion in a cache indexed by x (based on object ID). Whenever the cache already contains x, then:

  • if prune is specified, prune is returned, or
  • if prune is unspecified, the previously cached result of the recursion on (x, ys...) is returned.

Typically wraps an existing walk for use with fmap.
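A sketch of how these two wrappers compose, assuming Functors.jl v0.5: wrapping DefaultWalk in ExcludeWalk and then CachedWalk, and running the result with execute, mirrors the wrapping that fmap performs, so the output should match the corresponding fmap call.

```julia
using Functors  # assumes Functors.jl v0.5 or later

tree = (a = [1, 2], b = (3,))
scale10(x) = 10 .* x  # hypothetical leaf function

# ExcludeWalk applies scale10 at leaves (isleaf) and otherwise defers to DefaultWalk;
# CachedWalk adds the cache for repeated mutable nodes.
walk = Functors.CachedWalk(Functors.ExcludeWalk(Functors.DefaultWalk(), scale10, Functors.isleaf))

out = Functors.execute(walk, tree)
out == fmap(scale10, tree)  # true
```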


Moving models, or data, to the GPU

Flux provides some convenience functions based on fmap. Some (f16, f32, f64) change the precision of all arrays in a model. Others are used for moving a model to or from GPU memory:
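A quick sketch of the precision helpers, assuming Flux v0.15 or later: each returns a model with the same structure but with its arrays converted to the new element type.

```julia
using Flux  # assumes Flux v0.15 or later

m64 = Dense(rand(2, 3))  # Float64 weights, with a Float64 bias created to match
m32 = f32(m64)           # same structure, Float32 arrays
m16 = f16(m64)           # same structure, Float16 arrays
```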

Flux.cpu (function)
cpu(m)

Copies m onto the CPU, the opposite of gpu. Recurses into structs (thanks to Functors.jl).

Example

julia> m_gpu = Dense(CUDA.randn(2, 5))
Dense(5 => 2)       # 12 parameters

julia> m_gpu.bias  # matches the given weight matrix
2-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 0.0
 0.0

julia> m = m_gpu |> cpu
Dense(5 => 2)       # 12 parameters

julia> m.bias
2-element Vector{Float32}:
 0.0
 0.0
Flux.gpu (method)
gpu(m)

Copies m to the current GPU device (using the current GPU backend), if one is available. If no GPU is available, it does nothing (but prints a warning the first time). It recurses into structs according to Functors.jl.

Use cpu to copy back to ordinary Arrays. See also f32 and f16 to change element type only.

This function is a convenience wrapper around gpu_device, and is equivalent to gpu_device()(m). You may consider defining device = gpu_device() once and then using device(m) to move data.
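A minimal sketch of that pattern, assuming Flux v0.15 or later: the device object is created once and reused for both the model and its input, and the output is copied back with cpu. On a machine without a functional GPU backend, gpu_device() falls back to the CPU, so the same code still runs.

```julia
using Flux  # assumes Flux v0.15 or later

device = gpu_device()  # falls back to the CPU device if no functional GPU backend is loaded

model = Dense(3 => 2) |> device
x = rand(Float32, 3, 4) |> device

y = model(x) |> cpu  # bring the result back as an ordinary Array
```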

Example

julia> m = Dense(rand(2, 3))  # constructed with Float64 weight matrix
Dense(3 => 2)       # 8 parameters

julia> typeof(m.weight)
Matrix{Float64} (alias for Array{Float64, 2})

julia> m_gpu = gpu(m)  # can equivalently be written m_gpu = m |> gpu
Dense(3 => 2)       # 8 parameters

julia> typeof(m_gpu.weight)
CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
Flux.gpu (method)
gpu(data::DataLoader)
cpu(data::DataLoader)

Transforms a given DataLoader to apply gpu or cpu to each batch of data, when iterated over. (If no GPU is available, this does nothing.)

Example

julia> dl = Flux.DataLoader((x = ones(2,10), y='a':'j'), batchsize=3)
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}, batchsize=3)
  with first element:
  (; x = 2×3 Matrix{Float64}, y = 3-element StepRange{Char, Int64})

julia> first(dl)
(x = [1.0 1.0 1.0; 1.0 1.0 1.0], y = 'a':1:'c')

julia> c_dl = gpu(dl)
4-element DataLoader(::MLUtils.MappedData{:auto, typeof(gpu), NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}}, batchsize=3)
  with first element:
  (; x = 2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element StepRange{Char, Int64})

julia> first(c_dl).x
2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 1.0  1.0  1.0
 1.0  1.0  1.0

For large datasets, this is preferred over moving all the data to the GPU before creating the DataLoader, like this:

julia> Flux.DataLoader((x = ones(2,10), y=2:11) |> gpu, batchsize=3)
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, UnitRange{Int64}}}, batchsize=3)
  with first element:
  (; x = 2×3 CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element UnitRange{Int64})
Warning

This only works if gpu is applied directly to the DataLoader. While gpu acts recursively on Flux models and many basic Julia structs, it will not work on (say) a tuple of DataLoaders.
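When several loaders are kept together, a simple workaround, sketched here assuming Flux v0.15 or later with hypothetical data, is to apply gpu to each DataLoader individually, for example with map over the tuple.

```julia
using Flux  # assumes Flux v0.15 or later

train_loader = Flux.DataLoader((x = rand(Float32, 2, 10), y = rand(Float32, 10)), batchsize = 4)
test_loader  = Flux.DataLoader((x = rand(Float32, 2, 6),  y = rand(Float32, 6)),  batchsize = 4)

# Apply gpu to each DataLoader directly, rather than to the tuple holding them.
train_gpu, test_gpu = map(gpu, (train_loader, test_loader))
```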
