# Recursive transformations from Functors.jl

Flux models are deeply nested structures, and Functors.jl provides tools needed to explore such objects, apply functions to the parameters they contain, and re-build them.

All layers were previously defined with the `Functors.@functor` macro. This still works, but it is recommended that you use the new `Flux.@layer` macro instead. Both allow `Flux.setup` to see the parameters inside, and `gpu` to move them to the GPU, but `Flux.@layer` also overloads printing, and offers a way to define `trainable` at the same time.

`Functors.jl` has its own notes on basic usage for more details. Additionally, the Advanced Model Building and Customisation page covers the use cases of `Functors` in greater detail.

`Flux.@layer` — Macro

```
@layer Dense
@layer :expand Chain
@layer BatchNorm trainable=(β,γ)
```

This macro replaces most uses of `@functor`. Its basic purpose is the same: When you define a new layer, this tells Flux to explore inside it to see the parameters it trains, and also to move them to the GPU, change precision, etc.

Like `@functor`, this assumes your struct has the default constructor, to enable re-building. If you define an inner constructor (i.e. a function within the `struct` block) things may break.

The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`. Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes.

The macro also handles overloads of `show` for pretty printing.

- By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
- If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
- To disable all `show` overloads, there is an `:ignore` option too.

(You probably still want to define 2-arg `show(io::IO, x::Layer)`; the macro does not touch this.)

Note that re-running the macro with different options may not remove all methods; to get a clean state you will need to restart Julia.

**Example**

```
julia> struct Trio; a; b; c end
julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4))
Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
julia> Flux.destructure(tri) # parameters are not yet visible to Flux
(Bool[], Restructure(Trio, ..., 0))
julia> Flux.@layer :expand Trio
julia> Flux.destructure(tri) # now gpu, params, train!, etc will see inside too
([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4))
julia> tri # and layer is printed like Chain
Trio(
  Dense(2 => 1, tanh), # 3 parameters
  Dense(1 => 1; bias=false), # 1 parameters
  Dropout(0.4),
) # Total: 3 arrays, 4 parameters, 224 bytes.
```
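The `trainable` keyword described above is not exercised by the example. The following is a minimal sketch of it, assuming a hypothetical layer type `Scaled` (not part of Flux) whose `scale` field should be moved and converted along with the rest of the layer but never updated by the optimiser.

```julia
using Flux

# Hypothetical layer, for illustration only: `scale` is a fixed array.
struct Scaled
    W
    b
    scale
end

# Mark the layer, and restrict training to the fields W and b.
Flux.@layer Scaled trainable=(W, b)

(a::Scaled)(x) = a.scale .* (a.W * x .+ a.b)

m = Scaled(randn(Float32, 2, 3), zeros(Float32, 2), ones(Float32, 2))

# Only the fields named in `trainable` are reported as trainable...
trainable_fields = keys(Flux.trainable(m))

# ...and only their entries are flattened by destructure (6 + 2 values).
flat, _ = Flux.destructure(m)
n_trainable = length(flat)
```

All three fields are still children of the layer, so functions such as `gpu` and `f32` continue to act on `scale`; the keyword only narrows what the optimiser sees.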

`Functors.@functor` — Macro

```
@functor T
@functor T (x,)
```

Adds methods to `functor` allowing recursion into objects of type `T`, and reconstruction. Assumes that `T` has a constructor accepting all of its fields, which is true unless you have provided an inner constructor which does not.

By default all fields of `T` are considered `children`; this can be restricted by providing a tuple of field names.

**Examples**

```
julia> struct Foo; x; y; end
julia> @functor Foo
julia> Functors.children(Foo(1,2))
(x = 1, y = 2)
julia> _, re = Functors.functor(Foo(1,2));
julia> re((10, 20))
Foo(10, 20)
julia> struct TwoThirds a; b; c; end
julia> @functor TwoThirds (a, c)
julia> ch2, re3 = Functors.functor(TwoThirds(10,20,30));
julia> ch2
(a = 10, c = 30)
julia> re3(("ten", "thirty"))
TwoThirds("ten", 20, "thirty")
julia> fmap(x -> 10x, TwoThirds(Foo(1,2), Foo(3,4), 56))
TwoThirds(Foo(10, 20), Foo(3, 4), 560)
```

`Functors.fmap` — Function

`fmap(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalk(), [prune])`

A structure and type preserving `map`.

By default it transforms every leaf node (identified by `exclude`, default `isleaf`) by applying `f`, and otherwise traverses `x` recursively using `functor`. Optionally, it may also be associated with objects `ys` with the same tree structure. In that case, `f` is applied to the corresponding leaf nodes in `x` and `ys`.

See also `fmap_with_path` and `fmapstructure`.

**Examples**

```
julia> fmap(string, (x=1, y=(2, 3)))
(x = "1", y = ("2", "3"))
julia> nt = (a = [1,2], b = [23, (45,), (x=6//7, y=())], c = [8,9]);
julia> fmap(println, nt)
[1, 2]
23
45
6//7
()
[8, 9]
(a = nothing, b = Any[nothing, (nothing,), (x = nothing, y = nothing)], c = nothing)
julia> fmap(println, nt; exclude = x -> x isa Array)
[1, 2]
Any[23, (45,), (x = 6//7, y = ())]
[8, 9]
(a = nothing, b = nothing, c = nothing)
julia> twice = [1, 2]; # println only acts once on this
julia> fmap(println, (i = twice, ii = 34, iii = [5, 6], iv = (twice, 34), v = 34.0))
[1, 2]
34
[5, 6]
34
34.0
(i = nothing, ii = nothing, iii = nothing, iv = (nothing, nothing), v = nothing)
julia> d1 = Dict("x" => [1,2], "y" => 3);
julia> d2 = Dict("x" => [4,5], "y" => 6, "z" => "an_extra_value");
julia> fmap(+, d1, d2) == Dict("x" => [5, 7], "y" => 9) # Note that "z" is ignored
true
```

Mutable objects which appear more than once are only handled once (by caching `f(x)` in an `IdDict`). Thus the relationship `x.i === x.iv[1]` will be preserved. An immutable object which appears twice is not stored in the cache, thus `f(34)` will be called twice, and the results will agree only if `f` is pure.
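The caching behaviour can be checked directly. This is a small sketch, with variable names chosen only for illustration, that counts how often `f` is called on arrays and then tests that the shared array is still shared in the output.

```julia
using Functors

shared = [1.0, 2.0]
x = (i = shared, ii = 34, iv = (shared, 5))

# Count how many times f is applied to an array.
array_calls = Ref(0)

y = fmap(x) do v
    v isa Array && (array_calls[] += 1)
    2 .* v
end

# The same output object appears in both places where `shared` appeared.
still_tied = y.i === y.iv[1]
```

Here `array_calls[]` ends up at 1, because the second occurrence of `shared` is served from the cache rather than passed to `f` again.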

By default, `Tuple`s, `NamedTuple`s, and some other container-like types in Base have children to recurse into. Arrays of numbers do not. To enable recursion into new types, you must provide a method of `functor`, which can be done using the macro `@functor`:

```
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7))));
julia> fmap(x -> 10x, m)
Foo(Bar([10, 20, 30]), (40, 50, Bar(Foo(60, 70))))
julia> fmap(string, m)
Foo(Bar("[1, 2, 3]"), ("4", "5", Bar(Foo("6", "7"))))
julia> fmap(string, m, exclude = v -> v isa Bar)
Foo("Bar([1, 2, 3])", (4, 5, "Bar(Foo(6, 7))"))
```

To recurse into custom types without reconstructing them afterwards, use `fmapstructure`.

For advanced customization of the traversal behaviour, pass a custom `walk` function that subtypes `Functors.AbstractWalk`. The call `fmap(f, x, ys...; walk = mywalk)` will wrap `mywalk` in `ExcludeWalk` then `CachedWalk`. Here, `ExcludeWalk` is responsible for applying `f` at excluded nodes. For a low-level interface for executing a user-constructed walk, see `execute`.

```
julia> struct MyWalk <: Functors.AbstractWalk end
julia> (::MyWalk)(recurse, x) = x isa Bar ? "hello" :
Functors.DefaultWalk()(recurse, x)
julia> fmap(x -> 10x, m; walk = MyWalk())
Foo("hello", (40, 50, "hello"))
```

The behaviour when the same node appears twice can be altered by giving a value to the `prune` keyword, which is then used in place of all but the first:

```
julia> twice = [1, 2];
julia> fmap(float, (x = twice, y = [1,2], z = twice); prune = missing)
(x = [1.0, 2.0], y = [1.0, 2.0], z = missing)
```

`Functors.fmap_with_path` — Function

`fmap_with_path(f, x, ys...; exclude = isleaf, walk = DefaultWalkWithPath(), [prune])`

Like `fmap`, but also passes a `KeyPath` to `f` for each node in the recursion. The `KeyPath` is a tuple of the indices used to reach the current node from the root of the recursion. It is constructed by the `walk` function, and can be used to reconstruct the path to the current node from the root of the recursion.

`f` has to accept two arguments: the associated `KeyPath` and the value of the current node.

`exclude` also receives the `KeyPath` as its first argument and a node as its second. It should return `true` if the recursion should stop at that node, in which case `f` is applied to the node itself rather than to its children.

`prune` is used to control the behaviour when the same node appears twice, see `fmap` for more information.

**Examples**

```
julia> x = ([1, 2, 3], 4, (a=5, b=Dict("A"=>6, "B"=>7), c=Dict("C"=>8, "D"=>9)));
julia> exclude(kp, x) = kp == KeyPath(3, :c) || Functors.isleaf(x);
julia> fmap_with_path((kp, x) -> x isa Dict ? nothing : x.^2, x; exclude = exclude)
([1, 4, 9], 16, (a = 25, b = Dict("B" => 49, "A" => 36), c = nothing))
```

`Functors.isleaf` — Function

`Functors.isleaf(x)`

Return true if `x` has no `children` according to `functor`.

**Examples**

```
julia> Functors.isleaf(1)
true
julia> Functors.isleaf([2, 3, 4])
true
julia> Functors.isleaf(["five", [6, 7]])
false
julia> Functors.isleaf([])
false
julia> Functors.isleaf((8, 9))
false
julia> Functors.isleaf(())
true
```

`Functors.children` — Function

`Functors.children(x)`

Return the children of `x` as defined by `functor`. Equivalent to `functor(x)[1]`.
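As a quick illustration of that equivalence, here is a sketch using a hypothetical struct `Pair2` (a name invented for this example) marked with `@functor`:

```julia
using Functors

struct Pair2          # hypothetical type, for illustration only
    left
    right
end
@functor Pair2

p = Pair2([1, 2], (3, 4))

# children returns the first element of what functor returns.
ch = Functors.children(p)
same = ch == Functors.functor(p)[1]

# For a struct marked with @functor, the children are keyed by field name.
names = keys(ch)
```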

`Functors.fcollect` — Function

`fcollect(x; exclude = v -> false)`

Traverse `x` by recursing into each child of `x` as defined by `functor` and collecting the results into a flat array, ordered by a breadth-first traversal of `x`, respecting the iteration order of `children` calls.

Doesn't recurse inside branches rooted at nodes `v` for which `exclude(v) == true`. In such cases, the root `v` is also excluded from the result. By default, `exclude` always yields `false`.

See also `children`.

**Examples**

```
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> struct TypeWithNoChildren; x; y; end
julia> m = Foo(Bar([1,2,3]), TypeWithNoChildren(:a, :b))
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
julia> fcollect(m)
4-element Vector{Any}:
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
Bar([1, 2, 3])
[1, 2, 3]
TypeWithNoChildren(:a, :b)
julia> fcollect(m, exclude = v -> v isa Bar)
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
TypeWithNoChildren(:a, :b)
julia> fcollect(m, exclude = v -> Functors.isleaf(v))
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
Bar([1, 2, 3])
```

`Functors.functor` — Function

`Functors.functor(x) = functor(typeof(x), x)`

Returns a tuple containing, first, a `NamedTuple` of the children of `x` (typically its fields), and second, a reconstruction function. This controls the behaviour of `fmap`.

Methods should be added to `functor(::Type{T}, x)` for custom types, usually using the macro `@functor`.
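For cases where the macro is not suitable, a method can be written by hand. The sketch below assumes a hypothetical type `Wrapper` (not from Functors.jl or Flux) in which only `data` should be visited, while `label` is carried over unchanged by the reconstruction function.

```julia
using Functors

struct Wrapper        # hypothetical type, for illustration only
    data
    label
end

# Children: just `data`. Reconstruction: rebuild from the new children,
# copying `label` from the original object.
function Functors.functor(::Type{<:Wrapper}, w)
    children = (data = w.data,)
    rebuild(nt) = Wrapper(nt.data, w.label)
    return children, rebuild
end

w2 = fmap(x -> 2x, Wrapper([1, 2], "keep"))
```

With this method in place, `fmap` doubles the array in `data` and leaves `label` alone, which is roughly the effect of `@functor Wrapper (data,)`.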

`Functors.fmapstructure` — Function

`fmapstructure(f, x, ys...; exclude = isleaf, [prune])`

Like `fmap`, but doesn't preserve the type of custom structs. Instead, it returns a `NamedTuple` (or a `Tuple`, or an array), or a nested set of these.

Useful for when the output must not contain custom structs.

See also `fmap` and `fmapstructure_with_path`.

**Examples**

```
julia> struct Foo; x; y; end
julia> @functor Foo
julia> m = Foo([1,2,3], [4, (5, 6), Foo(7, 8)]);
julia> fmapstructure(x -> 2x, m)
(x = [2, 4, 6], y = Any[8, (10, 12), (x = 14, y = 16)])
julia> fmapstructure(println, m)
[1, 2, 3]
4
5
6
7
8
(x = nothing, y = Any[nothing, (nothing, nothing), (x = nothing, y = nothing)])
```

`Functors.fmapstructure_with_path` — Function

`fmapstructure_with_path(f, x, ys...; [exclude, prune])`

Like `fmap_with_path`, but doesn't preserve the type of custom structs. Instead, it returns a named tuple, a tuple, an array, a dict, or a nested set of these.

See also `fmapstructure`.
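This docstring has no example of its own, so here is a small sketch. It assumes a hypothetical struct `Box` and uses the path to single out one leaf; the `KeyPath(:y, :a)` value assumes that fields of a `@functor` struct are keyed by their names.

```julia
using Functors

struct Box            # hypothetical type, for illustration only
    x
    y
end
@functor Box

m = Box([1, 2], (a = 3, b = 4))

# Replace only the leaf reached via .y.a; return every other leaf as it is.
out = Functors.fmapstructure_with_path(m) do kp, v
    kp == Functors.KeyPath(:y, :a) ? 0 : v
end

# The Box wrapper is not rebuilt: the result is a plain NamedTuple.
is_plain = out isa NamedTuple
```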

`Functors.execute` — Function

`execute(walk, x, ys...)`

Execute a `walk` that recursively calls itself, starting at a node `x` in a Functors tree, as well as optional associated nodes `ys...` in other Functors trees. Any custom `walk` function that subtypes `Functors.AbstractWalk` is permitted.

`Functors.AbstractWalk` — Type

`AbstractWalk`

Any walk for use with `fmap` should inherit from this type. A walk subtyping `AbstractWalk` must satisfy the walk function interface:

```
struct MyWalk <: AbstractWalk end
function (::MyWalk)(recurse, x, ys...)
  # implement this
end
```

The walk function is called on a node `x` in a Functors tree. It may also be passed associated nodes `ys...` in other Functors trees. The walk function recurses further into `(x, ys...)` by calling `recurse` on the child nodes. The choice of which nodes to recurse and in what order is custom to the walk.

`Functors.ExcludeWalk` — Type

`ExcludeWalk(walk, fn, exclude)`

A walk that recurses nodes `(x, ys...)` according to `walk`, except when `exclude(x)` is true. Then, `fn(x, ys...)` is applied instead of recursing further.

Typically wraps an existing `walk` for use with `fmap`.
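To see how the walk types and `execute` fit together, the following sketch builds an `ExcludeWalk` by hand and runs it. It is meant as an approximation of what `fmap(x -> 10x, ...)` does, minus the caching layer, and not as a description of the exact internals.

```julia
using Functors

# Apply x -> 10x at leaves, and use the default walk everywhere else.
walk = Functors.ExcludeWalk(Functors.DefaultWalk(), x -> 10x, Functors.isleaf)

# Run the walk from the root of a small nested structure.
y = Functors.execute(walk, (a = 1, b = (2, 3)))
```

For this input the result matches `fmap(x -> 10x, (a = 1, b = (2, 3)))`, since no node appears twice and so caching would make no difference.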

`Functors.CachedWalk` — Type

`CachedWalk(walk[; prune])`

A walk that recurses nodes `(x, ys...)` according to `walk`, storing the output of the recursion in a cache indexed by `x` (based on object ID). Whenever the cache already contains `x`, either:

- `prune` is specified, then it is returned, or
- `prune` is unspecified, and the previously cached recursion of `(x, ys...)` is returned.

Typically wraps an existing `walk` for use with `fmap`.

## Moving models, or data, to the GPU

Flux provides some convenience functions based on `fmap`. Some (`f16`, `f32`, `f64`) change the precision of all arrays in a model. Others are used for moving a model to or from GPU memory:

`Flux.cpu` — Function

`cpu(m)`

Copies `m` onto the CPU, the opposite of `gpu`. Recurses into structs marked `@functor`.

**Example**

```
julia> m_gpu = Dense(CUDA.randn(2, 5))
Dense(5 => 2) # 12 parameters
julia> m_gpu.bias # matches the given weight matrix
2-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
0.0
0.0
julia> m = m_gpu |> cpu
Dense(5 => 2) # 12 parameters
julia> m.bias
2-element Vector{Float32}:
0.0
0.0
```

`Flux.gpu` — Method

`gpu(m)`

Copies `m` to the current GPU device (using current GPU backend), if one is available. If no GPU is available, it does nothing (but prints a warning the first time).

On arrays, this calls CUDA's `cu`, which also changes arrays with Float64 elements to Float32 while copying them to the device (same for AMDGPU). To act on arrays within a struct, the struct type must be marked with `@functor`.

Use `cpu` to copy back to ordinary `Array`s. See also `f32` and `f16` to change element type only.

See the CUDA.jl docs to help identify the current device.

**Example**

```
julia> m = Dense(rand(2, 3)) # constructed with Float64 weight matrix
Dense(3 => 2) # 8 parameters
julia> typeof(m.weight)
Matrix{Float64} (alias for Array{Float64, 2})
julia> m_gpu = gpu(m) # can equivalently be written m_gpu = m |> gpu
Dense(3 => 2) # 8 parameters
julia> typeof(m_gpu.weight)
CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
```
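The element-type change can also be tried without any GPU, using the precision functions `f32`, `f16` and `f64`. A minimal sketch, reusing the same kind of `Dense` layer as above:

```julia
using Flux

m = Dense(rand(2, 3))     # Float64 weight matrix, so Float64 bias too

m32 = f32(m)              # every array converted to Float32
m16 = f16(m)              # every array converted to Float16
m64 = f64(m32)            # and back to Float64

types = (eltype(m32.weight), eltype(m32.bias), eltype(m16.weight), eltype(m64.weight))
```

These functions stay on the CPU and only convert element types, so they are a convenient way to check which arrays inside a custom layer are being visited.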

`Flux.gpu` — Method

```
gpu(data::DataLoader)
cpu(data::DataLoader)
```

Transforms a given `DataLoader` to apply `gpu` or `cpu` to each batch of data, when iterated over. (If no GPU is available, this does nothing.)

**Example**

```
julia> dl = Flux.DataLoader((x = ones(2,10), y='a':'j'), batchsize=3)
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}, batchsize=3)
with first element:
(; x = 2×3 Matrix{Float64}, y = 3-element StepRange{Char, Int64})
julia> first(dl)
(x = [1.0 1.0 1.0; 1.0 1.0 1.0], y = 'a':1:'c')
julia> c_dl = gpu(dl)
4-element DataLoader(::MLUtils.MappedData{:auto, typeof(gpu), NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}}, batchsize=3)
with first element:
(; x = 2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element StepRange{Char, Int64})
julia> first(c_dl).x
2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
1.0 1.0 1.0
1.0 1.0 1.0
```

For large datasets, this is preferred over moving all the data to the GPU before creating the `DataLoader`, like this:

```
julia> Flux.DataLoader((x = ones(2,10), y=2:11) |> gpu, batchsize=3)
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, UnitRange{Int64}}}, batchsize=3)
with first element:
(; x = 2×3 CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element UnitRange{Int64})
```

This only works if `gpu` is applied directly to the `DataLoader`. While `gpu` acts recursively on Flux models and many basic Julia structs, it will not work on (say) a tuple of `DataLoader`s.
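One way around this, sketched here as a suggestion and not as an officially documented pattern, is to apply `gpu` to each `DataLoader` individually, for example with `map` over the tuple. The loader names are made up for illustration.

```julia
using Flux

train_dl = Flux.DataLoader((x = ones(2, 10), y = 1:10), batchsize = 3)
test_dl  = Flux.DataLoader((x = ones(2, 4),  y = 1:4),  batchsize = 2)

# gpu is called on each DataLoader directly, not on the enclosing tuple.
train_gpu, test_gpu = map(gpu, (train_dl, test_dl))

n_batches = (length(train_gpu), length(test_gpu))
```

Each resulting loader then transfers its batches lazily during iteration, as in the single-loader example above.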