Recursive transformations from Functors.jl
Flux models are deeply nested structures, and Functors.jl provides tools needed to explore such objects, apply functions to the parameters they contain, and re-build them.
All layers were previously defined with the Functors.@functor macro. This still works, but it is recommended that you use the new Flux.@layer macro instead. Both allow Flux.setup to see the parameters inside, and gpu to move them to the GPU, but Flux.@layer also overloads printing, and offers a way to define trainable at the same time.
Functors.jl has its own notes on basic usage for more details. Additionally, the Advanced Model Building and Customisation page covers the use cases of Functors in greater detail.
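For instance, here is a minimal sketch of how the Flux.@layer macro, documented below, makes a new struct's parameters visible to Flux. The one-field layer type MyScale is hypothetical, purely for illustration:
julia> struct MyScale; s; end  # hypothetical layer holding one parameter array
julia> Flux.@layer MyScale
julia> Flux.destructure(MyScale([1.0, 2.0]))  # Flux.setup, gpu, etc. can now see inside
([1.0, 2.0], Restructure(MyScale, ..., 2))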
Flux.@layer — Macro
@layer Dense
@layer :expand Chain
@layer BatchNorm trainable=(β,γ)
This macro replaces most uses of @functor. Its basic purpose is the same: when you define a new layer, this tells Flux to explore inside it to see the parameters it trains, and also to move them to the GPU, change precision, etc.
Like @functor, this assumes your struct has the default constructor, to enable re-building. If you define an inner constructor (i.e. a function within the struct block), things may break.
The keyword trainable allows you to limit this exploration, instead of visiting all fieldnames(T). Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes.
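For example, a short sketch with a hypothetical layer type whose mask field should stay fixed during training:
julia> struct MaskedDense; W; b; mask; end  # hypothetical; mask is not trained
julia> Flux.@layer MaskedDense trainable=(W, b)
julia> Flux.trainable(MaskedDense(ones(2, 2), zeros(2), [true false; false true]))
(W = [1.0 1.0; 1.0 1.0], b = [0.0, 0.0])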
The macro also handles overloads of show for pretty printing.
- By default, it adds methods to 3-arg Base.show to treat your layer much like Dense or Conv.
- If your layer is a container, more like Chain or Parallel, then :expand makes show unfold its contents.
- To disable all show overloads, there is an :ignore option too.
(You probably still want to define 2-arg show(io::IO, x::Layer); the macro does not touch this.)
Note that re-running the macro with different options may not remove all methods; you will need to restart Julia.
Example
julia> struct Trio; a; b; c end
julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4))
Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
julia> Flux.destructure(tri) # parameters are not yet visible to Flux
(Bool[], Restructure(Trio, ..., 0))
julia> Flux.@layer :expand Trio
julia> Flux.destructure(tri) # now gpu, params, train!, etc will see inside too
([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4))
julia> tri # and layer is printed like Chain
Trio(
  Dense(2 => 1, tanh),                  # 3 parameters
  Dense(1 => 1; bias=false),            # 1 parameters
  Dropout(0.4),
)                   # Total: 3 arrays, 4 parameters, 224 bytes.
Functors.@functor — Macro
@functor T
@functor T (x,)
Adds methods to functor allowing recursion into objects of type T, and reconstruction. Assumes that T has a constructor accepting all of its fields, which is true unless you have provided an inner constructor which does not.
By default all fields of T are considered children; this can be restricted by providing a tuple of field names.
Examples
julia> struct Foo; x; y; end
julia> @functor Foo
julia> Functors.children(Foo(1,2))
(x = 1, y = 2)
julia> _, re = Functors.functor(Foo(1,2));
julia> re((10, 20))
Foo(10, 20)
julia> struct TwoThirds a; b; c; end
julia> @functor TwoThirds (a, c)
julia> ch2, re3 = Functors.functor(TwoThirds(10,20,30));
julia> ch2
(a = 10, c = 30)
julia> re3(("ten", "thirty"))
TwoThirds("ten", 20, "thirty")
julia> fmap(x -> 10x, TwoThirds(Foo(1,2), Foo(3,4), 56))
TwoThirds(Foo(10, 20), Foo(3, 4), 560)
Functors.fmap — Function
fmap(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalk(), [prune])
A structure and type preserving map.
By default it transforms every leaf node (identified by exclude, default isleaf) by applying f, and otherwise traverses x recursively using functor. Optionally, it may also be associated with objects ys with the same tree structure. In that case, f is applied to the corresponding leaf nodes in x and ys.
See also fmap_with_path and fmapstructure.
Examples
julia> fmap(string, (x=1, y=(2, 3)))
(x = "1", y = ("2", "3"))
julia> nt = (a = [1,2], b = [23, (45,), (x=6//7, y=())], c = [8,9]);
julia> fmap(println, nt)
[1, 2]
23
45
6//7
()
[8, 9]
(a = nothing, b = Any[nothing, (nothing,), (x = nothing, y = nothing)], c = nothing)
julia> fmap(println, nt; exclude = x -> x isa Array)
[1, 2]
Any[23, (45,), (x = 6//7, y = ())]
[8, 9]
(a = nothing, b = nothing, c = nothing)
julia> twice = [1, 2]; # println only acts once on this
julia> fmap(println, (i = twice, ii = 34, iii = [5, 6], iv = (twice, 34), v = 34.0))
[1, 2]
34
[5, 6]
34
34.0
(i = nothing, ii = nothing, iii = nothing, iv = (nothing, nothing), v = nothing)
julia> d1 = Dict("x" => [1,2], "y" => 3);
julia> d2 = Dict("x" => [4,5], "y" => 6, "z" => "an_extra_value");
julia> fmap(+, d1, d2) == Dict("x" => [5, 7], "y" => 9) # Note that "z" is ignored
true
Mutable objects which appear more than once are only handled once (by caching f(x) in an IdDict). Thus the relationship x.i === x.iv[1] will be preserved. An immutable object which appears twice is not stored in the cache, thus f(34) will be called twice, and the results will agree only if f is pure.
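A small check of the sharing claim (a sketch, not part of the original docstring):
julia> twice = [1, 2];
julia> out = fmap(x -> 2 .* x, (i = twice, iv = (twice, 34)));
julia> out.i === out.iv[1]  # the shared mutable array is still shared after fmap
true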
By default, Tuples, NamedTuples, and some other container-like types in Base have children to recurse into. Arrays of numbers do not. To enable recursion into new types, you must provide a method of functor, which can be done using the macro @functor:
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7))));
julia> fmap(x -> 10x, m)
Foo(Bar([10, 20, 30]), (40, 50, Bar(Foo(60, 70))))
julia> fmap(string, m)
Foo(Bar("[1, 2, 3]"), ("4", "5", Bar(Foo("6", "7"))))
julia> fmap(string, m, exclude = v -> v isa Bar)
Foo("Bar([1, 2, 3])", (4, 5, "Bar(Foo(6, 7))"))
To recurse into custom types without reconstructing them afterwards, use fmapstructure.
For advanced customization of the traversal behaviour, pass a custom walk function that subtypes Functors.AbstractWalk. The call fmap(f, x, ys...; walk = mywalk) will wrap mywalk in ExcludeWalk then CachedWalk. Here, ExcludeWalk is responsible for applying f at excluded nodes. For a low-level interface for executing a user-constructed walk, see execute.
julia> struct MyWalk <: Functors.AbstractWalk end
julia> (::MyWalk)(recurse, x) = x isa Bar ? "hello" :
Functors.DefaultWalk()(recurse, x)
julia> fmap(x -> 10x, m; walk = MyWalk())
Foo("hello", (40, 50, "hello"))
The behaviour when the same node appears twice can be altered by giving a value to the prune keyword, which is then used in place of all but the first:
julia> twice = [1, 2];
julia> fmap(float, (x = twice, y = [1,2], z = twice); prune = missing)
(x = [1.0, 2.0], y = [1.0, 2.0], z = missing)
Functors.fmap_with_path — Function
fmap_with_path(f, x, ys...; exclude = isleaf, walk = DefaultWalkWithPath(), [prune])
Like fmap, but also passes a KeyPath to f for each node in the recursion. The KeyPath is a tuple of the indices used to reach the current node from the root of the recursion; it is constructed by the walk function.
f has to accept two arguments: the associated KeyPath and the value of the current node.
exclude also receives the KeyPath as its first argument and a node as its second. It should return true if the recursion should not continue on the node's children, in which case f is applied to it instead.
prune is used to control the behaviour when the same node appears twice; see fmap for more information.
Examples
julia> x = ([1, 2, 3], 4, (a=5, b=Dict("A"=>6, "B"=>7), c=Dict("C"=>8, "D"=>9)));
julia> exclude(kp, x) = kp == KeyPath(3, :c) || Functors.isleaf(x);
julia> fmap_with_path((kp, x) -> x isa Dict ? nothing : x.^2, x; exclude = exclude)
([1, 4, 9], 16, (a = 25, b = Dict("B" => 49, "A" => 36), c = nothing))
Functors.isleaf — Function
Functors.isleaf(x)
Return true if x has no children according to functor.
Examples
julia> Functors.isleaf(1)
true
julia> Functors.isleaf([2, 3, 4])
true
julia> Functors.isleaf(["five", [6, 7]])
false
julia> Functors.isleaf([])
false
julia> Functors.isleaf((8, 9))
false
julia> Functors.isleaf(())
true
Functors.children — Function
Functors.children(x)
Return the children of x as defined by functor. Equivalent to functor(x)[1].
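For example (Foo is the same toy struct used in the @functor examples above):
julia> struct Foo; x; y; end
julia> @functor Foo
julia> Functors.children(Foo(1, 2))
(x = 1, y = 2)
julia> Functors.children((a = 3, b = [4, 5]))  # a NamedTuple is its own children
(a = 3, b = [4, 5])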
Functors.fcollect — Function
fcollect(x; exclude = v -> false)
Traverse x by recursing each child of x as defined by functor and collecting the results into a flat array, ordered by a breadth-first traversal of x, respecting the iteration order of children calls.
Doesn't recurse inside branches rooted at nodes v for which exclude(v) == true. In such cases, the root v is also excluded from the result. By default, exclude always yields false.
See also children.
Examples
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> struct TypeWithNoChildren; x; y; end
julia> m = Foo(Bar([1,2,3]), TypeWithNoChildren(:a, :b))
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
julia> fcollect(m)
4-element Vector{Any}:
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
Bar([1, 2, 3])
[1, 2, 3]
TypeWithNoChildren(:a, :b)
julia> fcollect(m, exclude = v -> v isa Bar)
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
TypeWithNoChildren(:a, :b)
julia> fcollect(m, exclude = v -> Functors.isleaf(v))
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
Bar([1, 2, 3])
Functors.functor — Function
Functors.functor(x) = functor(typeof(x), x)
Returns a tuple containing, first, a NamedTuple of the children of x (typically its fields), and second, a reconstruction function. This controls the behaviour of fmap.
Methods should be added to functor(::Type{T}, x) for custom types, usually using the macro @functor.
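For instance, a hand-written method roughly equivalent to what @functor would generate for a hypothetical Point type (a sketch; the macro-generated version may differ in details):
julia> struct Point; x; y; end
julia> Functors.functor(::Type{<:Point}, p) = (x = p.x, y = p.y), nt -> Point(nt.x, nt.y)
julia> fmap(x -> 10x, Point(1, 2))
Point(10, 20)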
Functors.fmapstructure — Function
fmapstructure(f, x, ys...; exclude = isleaf, [prune])
Like fmap, but doesn't preserve the type of custom structs. Instead, it returns a NamedTuple (or a Tuple, or an array), or a nested set of these.
Useful for when the output must not contain custom structs.
See also fmap and fmapstructure_with_path.
Examples
julia> struct Foo; x; y; end
julia> @functor Foo
julia> m = Foo([1,2,3], [4, (5, 6), Foo(7, 8)]);
julia> fmapstructure(x -> 2x, m)
(x = [2, 4, 6], y = Any[8, (10, 12), (x = 14, y = 16)])
julia> fmapstructure(println, m)
[1, 2, 3]
4
5
6
7
8
(x = nothing, y = Any[nothing, (nothing, nothing), (x = nothing, y = nothing)])
Functors.fmapstructure_with_path — Function
fmapstructure_with_path(f, x, ys...; [exclude, prune])
Like fmap_with_path, but doesn't preserve the type of custom structs. Instead, it returns a named tuple, a tuple, an array, a dict, or a nested set of these.
See also fmapstructure.
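This docstring has no example of its own; a small sketch of typical usage:
julia> nt = (a = [1, 2], b = (c = [3, 4], d = 5));
julia> fmapstructure_with_path(nt) do kp, x
           kp == KeyPath(:b, :c) ? nothing : x
       end
(a = [1, 2], b = (c = nothing, d = 5))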
Functors.execute — Function
execute(walk, x, ys...)
Execute a walk that recursively calls itself, starting at a node x in a Functors tree, as well as optional associated nodes ys... in other Functors trees. Any custom walk function that subtypes Functors.AbstractWalk is permitted.
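As a sketch (assuming the walk types documented below are constructed directly), composing ExcludeWalk and CachedWalk by hand and executing them reproduces a plain fmap call:
julia> walk = Functors.CachedWalk(Functors.ExcludeWalk(Functors.DefaultWalk(), x -> 10x, Functors.isleaf));
julia> Functors.execute(walk, (a = [1, 2], b = (c = 3,)))
(a = [10, 20], b = (c = 30,))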
Functors.AbstractWalk — Type
AbstractWalk
Any walk for use with fmap should inherit from this type. A walk subtyping AbstractWalk must satisfy the walk function interface:
struct MyWalk <: AbstractWalk end
function (::MyWalk)(recurse, x, ys...)
# implement this
end
The walk function is called on a node x in a Functors tree. It may also be passed associated nodes ys... in other Functors trees. The walk function recurses further into (x, ys...) by calling recurse on the child nodes. The choice of which nodes to recurse and in what order is custom to the walk.
Functors.ExcludeWalk — Type
ExcludeWalk(walk, fn, exclude)
A walk that recurses nodes (x, ys...) according to walk, except when exclude(x) is true. Then, fn(x, ys...) is applied instead of recursing further.
Typically wraps an existing walk for use with fmap.
Functors.CachedWalk — Type
CachedWalk(walk[; prune])
A walk that recurses nodes (x, ys...) according to walk, storing the output of the recursion in a cache indexed by x (based on object ID). Whenever the cache already contains x, either:
- prune is specified, then it is returned, or
- prune is unspecified, and the previously cached recursion of (x, ys...) is returned.
Typically wraps an existing walk for use with fmap.
Moving models, or data, to the GPU
Flux provides some convenience functions based on fmap. Some (f16, f32, f64) change the precision of all arrays in a model. Others are used for moving a model to or from GPU memory:
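For instance, a quick sketch of the precision helpers (Dense(rand(2, 3)) starts with Float64 weights, as in the gpu example further down):
julia> m = Dense(rand(2, 3));
julia> typeof(f32(m).weight)
Matrix{Float32} (alias for Array{Float32, 2})
julia> typeof(f16(m).weight)
Matrix{Float16} (alias for Array{Float16, 2})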
Flux.cpu — Function
cpu(m)
Copies m onto the CPU, the opposite of gpu. Recurses into structs marked @functor.
Example
julia> m_gpu = Dense(CUDA.randn(2, 5))
Dense(5 => 2) # 12 parameters
julia> m_gpu.bias # matches the given weight matrix
2-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
0.0
0.0
julia> m = m_gpu |> cpu
Dense(5 => 2) # 12 parameters
julia> m.bias
2-element Vector{Float32}:
0.0
0.0
Flux.gpu — Method
gpu(m)
Copies m to the current GPU device (using the current GPU backend), if one is available. If no GPU is available, it does nothing (but prints a warning the first time).
On arrays, this calls CUDA's cu, which also changes arrays with Float64 elements to Float32 while copying them to the device (same for AMDGPU). To act on arrays within a struct, the struct type must be marked with @functor.
Use cpu to copy back to ordinary Arrays. See also f32 and f16 to change element type only.
See the CUDA.jl docs to help identify the current device.
Example
julia> m = Dense(rand(2, 3)) # constructed with Float64 weight matrix
Dense(3 => 2) # 8 parameters
julia> typeof(m.weight)
Matrix{Float64} (alias for Array{Float64, 2})
julia> m_gpu = gpu(m) # can equivalently be written m_gpu = m |> gpu
Dense(3 => 2) # 8 parameters
julia> typeof(m_gpu.weight)
CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
Flux.gpu — Method
gpu(data::DataLoader)
cpu(data::DataLoader)
Transforms a given DataLoader to apply gpu or cpu to each batch of data, when iterated over. (If no GPU is available, this does nothing.)
Example
julia> dl = Flux.DataLoader((x = ones(2,10), y='a':'j'), batchsize=3)
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}, batchsize=3)
with first element:
(; x = 2×3 Matrix{Float64}, y = 3-element StepRange{Char, Int64})
julia> first(dl)
(x = [1.0 1.0 1.0; 1.0 1.0 1.0], y = 'a':1:'c')
julia> c_dl = gpu(dl)
4-element DataLoader(::MLUtils.MappedData{:auto, typeof(gpu), NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}}, batchsize=3)
with first element:
(; x = 2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element StepRange{Char, Int64})
julia> first(c_dl).x
2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
1.0 1.0 1.0
1.0 1.0 1.0
For large datasets, this is preferred over moving all the data to the GPU before creating the DataLoader, like this:
julia> Flux.DataLoader((x = ones(2,10), y=2:11) |> gpu, batchsize=3)
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, UnitRange{Int64}}}, batchsize=3)
with first element:
(; x = 2×3 CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element UnitRange{Int64})
This only works if gpu is applied directly to the DataLoader. While gpu acts recursively on Flux models and many basic Julia structs, it will not work on (say) a tuple of DataLoaders.
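Putting this together, a hedged sketch of the usual pattern: keep the DataLoader on the CPU and move each batch as it is produced. Here model, opt_state (from Flux.setup), loss, and train_loader are assumed to be defined elsewhere:
for (x, y) in gpu(train_loader)   # each batch is copied to the GPU as it is iterated
    grads = Flux.gradient(m -> loss(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
end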