Functors.jl
Flux makes use of Functors.jl to implement many of the core functionalities it provides.
Functors.jl is a collection of tools for working with functors: structs whose children can be traversed and mapped over while their structure is preserved. Flux uses it to treat certain structs as functors, most notably the layers that Flux defines.
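As a rough illustration of what treating a struct as a functor buys, the sketch below marks a layer-like struct with @functor and then uses fmap to convert every array inside it to Float32, getting back a value of the same struct type. MyDense and its fields weight and bias are hypothetical names made up for this example; they are not Flux's own layer definitions.

```julia
using Functors

# Hypothetical layer-like struct, loosely modelled on a dense layer.
struct MyDense
    weight
    bias
end
@functor MyDense   # let Functors.jl treat `weight` and `bias` as children

d = MyDense([1.0 2.0; 3.0 4.0], [0.0, 1.0])

# fmap applies the function to each leaf (here: the two arrays of numbers)
# and reconstructs a MyDense around the results.
d32 = fmap(x -> Float32.(x), d)
```

The struct definition itself knows nothing about Float32 conversion; @functor is the only hook needed for fmap to reach inside it.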
Functors.isleaf
Functors.isleaf(x)
Return true if x has no children according to functor.
Functors.children
Functors.children(x)
Return the children of x as defined by functor. Equivalent to functor(x)[1].
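A short sketch of children and isleaf on a few values. Point is a hypothetical struct introduced only for this example; the expected results in the comments follow from the definitions above (a @functor struct exposes its fields as a NamedTuple, and arrays of numbers have no children).

```julia
using Functors

# Hypothetical struct used only for illustration.
struct Point
    x
    y
end
@functor Point

p = Point([1, 2], 3.0)

Functors.children(p)                     # (x = [1, 2], y = 3.0): the fields as a NamedTuple
Functors.children((a = 1, b = [2, 3]))   # a NamedTuple's children are its own entries
Functors.isleaf([1, 2])                  # true: an array of numbers has no children
Functors.isleaf(p)                       # false: @functor gives Point children
```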
Functors.fcollect
fcollect(x; exclude = v -> false)
Traverse x by recursing into each child of x as defined by functor, and collect the results into a flat array, ordered by a depth-first (pre-order) traversal of x that respects the iteration order of children calls.
Does not recurse inside branches rooted at nodes v for which exclude(v) == true. In such cases, the root v is also excluded from the result. By default, exclude always yields false.
See also children.
Examples
julia> using Functors
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> struct NoChildren; x; y; end
julia> m = Foo(Bar([1,2,3]), NoChildren(:a, :b))
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
julia> fcollect(m)
4-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
Bar([1, 2, 3])
[1, 2, 3]
NoChildren(:a, :b)
julia> fcollect(m, exclude = v -> v isa Bar)
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
NoChildren(:a, :b)
julia> fcollect(m, exclude = v -> Functors.isleaf(v))
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
Bar([1, 2, 3])
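To make the traversal order concrete, here is a minimal sketch of a collector with the behaviour described above: depth-first, each node recorded before its children, exclude dropping both a node and its whole branch, and each node (by ===) visited once. my_fcollect, Outer and Inner are hypothetical names; this is an illustration written against Functors.children and Functors.isleaf, not the library's own implementation of fcollect.

```julia
using Functors

# Hypothetical re-implementation, for illustration only.
function my_fcollect(x; exclude = v -> false, seen = IdDict{Any,Nothing}(), out = Any[])
    haskey(seen, x) && return out      # visit each node (by ===) at most once
    exclude(x) && return out           # skip this node and everything below it
    seen[x] = nothing
    push!(out, x)                      # pre-order: record the node before its children
    Functors.isleaf(x) && return out   # leaves have nothing to recurse into
    for c in Functors.children(x)      # children in their defined iteration order
        my_fcollect(c; exclude = exclude, seen = seen, out = out)
    end
    return out
end

# Hypothetical structs mirroring the Foo/Bar example.
struct Inner; v; end
struct Outer; inner; rest; end
@functor Inner
@functor Outer

m = Outer(Inner([1, 2, 3]), (4, 5))
my_fcollect(m)                          # Outer, Inner, [1, 2, 3], (4, 5), 4, 5
my_fcollect(m; exclude = v -> v isa Inner)
```

The second call shows why excluded roots vanish along with their subtree: Inner([1, 2, 3]) and the array inside it are both missing from the result.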
Functors.functor
Functors.functor(x) = functor(typeof(x), x)
Returns a tuple containing, first, a NamedTuple of the children of x (typically its fields), and second, a reconstruction function. This controls the behaviour of fmap.
Methods should be added to functor(::Type{T}, x) for custom types, usually by using the macro @functor.
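The pair returned by functor can be used directly: take the children apart, change them, and hand new children back to the reconstruction function. Affine is a hypothetical struct for this sketch; the reconstruction function is called here with a NamedTuple whose entries are in the same order as the children.

```julia
using Functors

# Hypothetical layer-like struct.
struct Affine
    W
    b
end
@functor Affine

layer = Affine([1.0 2.0; 3.0 4.0], [0.5, -0.5])

# First element: NamedTuple of children; second: function rebuilding an Affine.
kids, rebuild = Functors.functor(layer)

# Rebuild with a modified weight and the original bias.
new_layer = rebuild((W = kids.W .* 2, b = kids.b))
```

This is the same mechanism fmap relies on internally: map over the children, then reconstruct.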
Functors.fmap
fmap(f, x; exclude = Functors.isleaf, walk = Functors._default_walk)
A structure- and type-preserving map.
By default it transforms every leaf node (identified by exclude, default isleaf) by applying f, and otherwise traverses x recursively using functor.
Examples
julia> using Functors
julia> fmap(string, (x=1, y=(2, 3)))
(x = "1", y = ("2", "3"))
julia> nt = (a = [1,2], b = [23, (45,), (x=6//7, y=())], c = [8,9]);
julia> fmap(println, nt)
[1, 2]
23
45
6//7
()
[8, 9]
(a = nothing, b = Any[nothing, (nothing,), (x = nothing, y = nothing)], c = nothing)
julia> fmap(println, nt; exclude = x -> x isa Array)
[1, 2]
Any[23, (45,), (x = 6//7, y = ())]
[8, 9]
(a = nothing, b = nothing, c = nothing)
julia> twice = [1, 2];
julia> fmap(println, (i = twice, ii = 34, iii = [5, 6], iv = (twice, 34), v = 34.0))
[1, 2]
34
[5, 6]
34.0
(i = nothing, ii = nothing, iii = nothing, iv = (nothing, nothing), v = nothing)
If the same node (same according to ===) appears more than once, it is handled only once and transformed with f only once. The corresponding nodes in the result are therefore also identical (===).
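A small check of the sharing rule, under the assumption that the leaves are arrays (so identity is meaningful): an array referenced twice comes out as a single transformed array referenced twice, while an equal but distinct array is transformed separately. The names shared and nt are illustrative only.

```julia
using Functors

shared = [1.0, 2.0]
nt = (a = shared, b = shared, c = [1.0, 2.0])   # a and b are the same object; c is only equal

out = fmap(x -> 2 .* x, nt)

out.a === out.b   # true: the shared node was transformed once and stays shared
out.a === out.c   # false: c was a distinct node, so it got its own result
```

This matters for models with tied weights: after a transformation such as a type conversion, the tie is preserved rather than silently split into two copies.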
By default, Tuples, NamedTuples, and some other container-like types in Base have children to recurse into. Arrays of numbers do not. To enable recursion into new types, you must provide a method of functor, which can be done using the macro @functor:
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7))));
julia> fmap(x -> 10x, m)
Foo(Bar([10, 20, 30]), (40, 50, Bar(Foo(60, 70))))
julia> fmap(string, m)
Foo(Bar("[1, 2, 3]"), ("4", "5", Bar(Foo("6", "7"))))
julia> fmap(string, m, exclude = v -> v isa Bar)
Foo("Bar([1, 2, 3])", (4, 5, "Bar(Foo(6, 7))"))
To recurse into custom types without reconstructing them afterwards, use fmapstructure.
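A sketch of the difference: with fmapstructure the leaves are transformed as with fmap, but custom structs come back as plain NamedTuples of their (transformed) children instead of being rebuilt. Wrapper is a hypothetical struct for this example, and the function is called with its qualified name to avoid relying on it being exported.

```julia
using Functors

# Hypothetical struct used only for illustration.
struct Wrapper
    inner
end
@functor Wrapper

w = Wrapper((a = [1, 2], b = 3))

fmap(x -> 10x, w)                      # still a Wrapper: Wrapper((a = [10, 20], b = 30))
s = Functors.fmapstructure(x -> 10x, w)
# s == (inner = (a = [10, 20], b = 30),): same shape, but no Wrapper
```

This is handy for exporting or inspecting the parameters of a model as plain nested data.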
For advanced customization of the traversal behaviour, pass a custom walk function of the form (f', xs) -> .... This function walks (maps) over xs, calling the continuation f' to continue the traversal.
julia> fmap(x -> 10x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x))
Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7))))
The behaviour when the same node appears more than once can be altered by giving a value to the prune keyword, which is then used in place of all but the first occurrence:
julia> twice = [1, 2];
julia> fmap(float, (x = twice, y = [1,2], z = twice); prune = missing)
(x = [1.0, 2.0], y = [1.0, 2.0], z = missing)