Chain
struct, defined in module Flux
Chain(layers...)
Chain(name = layer, ...)
Collects multiple layers / functions to be called in sequence on a given input. Supports indexing and slicing (m[2] or m[1:end-1]) and, if names are given, symbolic indexing such as m[:name] == m[1].
julia> using Flux
julia> m = Chain(x -> x^2, x -> x+1);
julia> m(5) == 26
true
julia> m = Chain(Dense(10 => 5, tanh), Dense(5 => 2));
julia> x = rand32(10, 32);
julia> m(x) == m[2](m[1](x))
true
julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10 => 5, tanh)),
                  dec = Dense(5 => 2));
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
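To make the sequencing and indexing behaviour concrete without depending on Flux, here is a minimal sketch in plain Julia. MiniChain, square and increment are hypothetical names introduced only for illustration; this is not Flux's implementation. It assumes the layers are stored in a Tuple (positional form) or a NamedTuple (named form) and applied left to right, with integer or symbol indexing returning a single layer and range indexing returning a shorter chain.

```julia
# Hypothetical MiniChain: a minimal sketch of the Chain idea, not Flux's code.
struct MiniChain{T<:Union{Tuple,NamedTuple}}
    layers::T
end

# Positional form, like Chain(layers...): store the layers as a Tuple.
MiniChain(xs...) = MiniChain(xs)
# Named form, like Chain(name = layer, ...): store them as a NamedTuple.
MiniChain(; kw...) = MiniChain(values(kw))

# Calling the chain feeds the output of each layer into the next, left to right.
(c::MiniChain)(x) = foldl((acc, layer) -> layer(acc), Tuple(c.layers); init = x)

# m[i] and m[:name] return a single layer.
Base.getindex(c::MiniChain, i::Integer) = c.layers[i]
Base.getindex(c::MiniChain, name::Symbol) = c.layers[name]
# m[1:end-1] returns a shorter chain (positional form only, for simplicity).
Base.getindex(c::MiniChain{<:Tuple}, r::AbstractUnitRange) = MiniChain(c.layers[r])
Base.lastindex(c::MiniChain) = length(c.layers)

square(x) = x^2      # hypothetical example layers
increment(x) = x + 1

m = MiniChain(square, increment)
println(m(5))                 # 26
println(m[2](m[1](5)))        # 26
println(m[1:end-1](5))        # 25, only square is applied

m2 = MiniChain(sq = square, inc = increment)
println(m2(5) == (m2[:inc] ∘ m2[:sq])(5))   # true
println(m2[:sq] === m2[1])                  # true
```

A NamedTuple supports both positional and symbolic lookup, so in this sketch m2[:sq] and m2[1] return the same object. Slicing is defined only for the positional form here, purely to keep the example short.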
For large models, there is a special type-unstable path that can reduce compilation times. It is used by supplying a vector of layers, Chain([layer1, layer2, ...]), instead of separate arguments. This feature is somewhat experimental; use it with caution.
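The trade-off behind the vector-of-layers path can be sketched in plain Julia. This illustration rests on a general Julia fact, not on Flux internals: a Tuple's type records the concrete type of every element, while a Vector{Function} records only the abstract type Function. apply_tuple and apply_vector are hypothetical helpers.

```julia
# Hypothetical helpers (not Flux internals) contrasting the two storage choices.

# Tuple path: peel off one layer at a time. The tuple's type records every
# layer's concrete type, so the compiler can specialize this recursion for
# the exact sequence of layers.
apply_tuple(::Tuple{}, x) = x
apply_tuple(layers::Tuple, x) = apply_tuple(Base.tail(layers), first(layers)(x))

# Vector path: a single loop over an abstractly typed container. The loop is
# compiled once per container type, but each layer(x) call is resolved at run time.
function apply_vector(layers::AbstractVector, x)
    for layer in layers
        x = layer(x)
    end
    return x
end

tuple_layers  = (x -> x^2, x -> x + 1)          # a Tuple whose type lists each function's type
vector_layers = Function[x -> x^2, x -> x + 1]  # element type is only the abstract Function

println(apply_tuple(tuple_layers, 5))           # 26
println(apply_vector(vector_layers, 5))         # 26
println(isconcretetype(typeof(tuple_layers)))   # true
println(isconcretetype(eltype(vector_layers)))  # false
```

Roughly, the tuple version gives the compiler full type information, which helps run-time speed but means new code is specialized for each distinct sequence of layers; the vector version gives up that specialization in exchange for dynamic dispatch on every layer call.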
There are 3 methods for Flux.Chain.