Parallel (struct, defined in module Flux)
Parallel(connection, layers...)
Parallel(connection; name = layer, ...)
Create a layer which passes an input array to each path in layers, before reducing the outputs with connection.
Called with one input x, this is equivalent to connection([l(x) for l in layers]...). If called with multiple inputs, one is passed to each layer, so that Parallel(+, f, g)(x, y) = f(x) + g(y).
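The two calling conventions described above can be illustrated with a small stand-alone sketch. MyParallel, f and g are hypothetical names introduced here; this is a minimal model of the documented behaviour under the assumption that layers are any callables, not Flux's actual source.

```julia
# Hypothetical minimal model of the documented semantics; NOT Flux's implementation.
struct MyParallel{F, T<:Tuple}
    connection::F   # reduction applied to the branch outputs
    layers::T       # the branches, stored as a tuple
    MyParallel(connection, layers...) =
        new{typeof(connection), typeof(layers)}(connection, layers)
end

# One input: every branch receives the same x, then the outputs are splatted
# into the connection, i.e. connection([l(x) for l in layers]...).
(m::MyParallel)(x) = m.connection(map(l -> l(x), m.layers)...)

# Several inputs: the i-th input is passed to the i-th branch.
function (m::MyParallel)(xs...)
    length(xs) == length(m.layers) ||
        throw(ArgumentError("expected $(length(m.layers)) inputs, got $(length(xs))"))
    return m.connection(map((l, x) -> l(x), m.layers, xs)...)
end

# Plain functions standing in for layers.
f(x) = 2 .* x
g(x) = x .+ 1

MyParallel(+, f, g)(3)      # f(3) + g(3) = 10
MyParallel(+, f, g)(3, 5)   # f(3) + g(5) = 12
```

The single-argument method is more specific than the varargs one, so Julia's dispatch picks it whenever exactly one input is given.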
Like Chain, its sub-layers may be given names using the keyword constructor. These can be accessed by indexing: if the first layer is given the name name, then m[1] == m[:name].
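Indexing by both position and name is exactly what Julia's NamedTuple offers, which makes it a natural container for named branches. The sketch below assumes that model (it is an illustration, not a statement about Flux's internals) and uses plain anonymous functions in place of layers.

```julia
# Assumption for illustration: named branches held in a NamedTuple.
layers = (α = x -> 2 .* x, β = x -> x .+ 1)

layers[1] === layers[:α]   # true: the same branch, by position or by name
layers[2] === layers.β     # true: property syntax also works
keys(layers)               # (:α, :β)
```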
See also SkipConnection, which is Parallel with one identity branch, and Maxout, which reduces by broadcasting max.
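Both relationships can be checked with plain functions. In this sketch f, g, skip_like, parallel_like and maxout_like are hypothetical stand-ins, not Flux layers; it assumes SkipConnection(f, +)(x) computes f(x) + x and that Maxout combines branch outputs with an elementwise max.

```julia
# Plain Julia functions standing in for layers (hypothetical).
f(x) = 2 .* x
g(x) = x .+ 1

# A skip connection adds the branch output back to the input ...
skip_like(x) = f(x) + x
# ... which is the same as a Parallel(+, f, identity) on a single input.
parallel_like(x) = +(f(x), identity(x))

# A Maxout-style reduction takes the elementwise (broadcast) max of the branches.
maxout_like(x) = max.(f(x), g(x))

x = [1.0, -2.0]
skip_like(x)     # [3.0, -6.0]
maxout_like(x)   # [2.0, -1.0]
```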
julia> model = Chain(Dense(3 => 5),
                     Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),
                     Dense(8 => 17));

julia> model(rand32(3)) |> size
(17,)

julia> model2 = Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2))
Parallel(
  +,
  α = Dense(10 => 2, tanh),             # 22 parameters
  β = Dense(5 => 2),                    # 12 parameters
)                   # Total: 4 arrays, 34 parameters, 392 bytes.

julia> model2(rand32(10), rand32(5)) |> size
(2,)

julia> model2[:α](rand32(10)) |> size
(2,)

julia> model2[:β] == model2[2]
true