# basic.jl

`Flux/layers/basic.jl` is a source file in the `Flux` module.
``````

"""

Chain(layers...)

Chain(name = layer, ...)

Collects multiple layers / functions to be called in sequence

on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`,

and if names are given, `m[:name] == m[1]` etc.

# Examples

```jldoctest

julia> m = Chain(x -> x^2, x -> x+1);

julia> m(5) == 26

true

julia> m = Chain(Dense(10 => 5, tanh), Dense(5 => 2));

julia> x = rand32(10, 32);

julia> m(x) == m[2](m[1](x))

true

julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10 => 5, tanh)),

dec = Dense(5 => 2));

julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)

true

```

For large models, there is a special type-unstable path which can reduce compilation

times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.

This feature is somewhat experimental, beware!

"""

struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}}
  layers::T
end

Chain(xs...) = Chain(xs)
function Chain(; kw...)
  :layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
  isempty(kw) && return Chain(())
  Chain(values(kw))
end

@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
  Base.iterate, Base.lastindex, Base.keys, Base.firstindex

@functor Chain

(c::Chain)(x) = _applychain(c.layers, x)

@generated function _applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
  symbols = vcat(:x, [gensym() for _ in 1:N])
  calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
  Expr(:block, calls...)
end

_applychain(layers::NamedTuple, x) = _applychain(Tuple(layers), x)

function _applychain(layers::AbstractVector, x)  # type-unstable path, helps compile times
  for f in layers
    x = f(x)
  end
  x
end
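# Illustration (not part of the original source): for a 2-layer Tuple, the generated
# body above unrolls to, roughly,
#     y1 = layers[1](x)
#     y2 = layers[2](y1)
# returning `y2`, so the whole forward pass stays type-stable.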

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
  Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))

function Base.show(io::IO, c::Chain)
  print(io, "Chain(")
  _show_layers(io, c.layers)
  print(io, ")")
end

_show_layers(io, layers::Tuple) = join(io, layers, ", ")
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))
``````
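The type-unstable vector path mentioned in the `Chain` docstring can be exercised by wrapping the layers in a `Vector` instead of splatting them. A minimal sketch (layer sizes here are arbitrary, not from the source):

```julia
using Flux

layers = (Dense(10 => 5, relu), Dense(5 => 2))
m_tuple  = Chain(layers...)         # Tuple-backed Chain, fully type-stable
m_vector = Chain(collect(layers))   # Vector-backed Chain: the type-unstable, faster-to-compile path

x = rand(Float32, 10, 32)
m_tuple(x) == m_vector(x)           # true: same layers, same result, different dispatch path
```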

This is a temporary and naive implementation; it might be replaced in the future for better performance. See issue https://github.com/FluxML/Flux.jl/issues/702. Johnny Chen -- only slightly changed to better handle interaction with Zygote.

``````

"""

activations(c::Chain, input)

Like calling a `Chain`, but saves the result of each layer as an output.

# Examples

```jldoctest

julia> using Flux: activations

julia> c = Chain(x -> x + 1, x -> x * 2, x -> x ^ 3);

julia> activations(c, 1)

(2, 4, 64)

```

"""

activations(c::Chain, input) = _extraChain(Tuple(c.layers), input)
``````

Calculates the forward results of each layer provided in a `Tuple` with `x` as model input.
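As a rough sketch (not from the source), the recursive `_extraChain` defined in the next block computes the same thing as this loop:

```julia
# Loop-based equivalent of _extraChain, for illustration only
function extrachain_sketch(fs::Tuple, x)
    outs = ()
    for f in fs
        x = f(x)
        outs = (outs..., x)   # keep every intermediate result
    end
    return outs
end
```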

``````

function _extraChain(fs::Tuple, x)
  res = first(fs)(x)
  return (res, _extraChain(Base.tail(fs), res)...)
end
_extraChain(::Tuple{}, x) = ()

"""

Dense(in => out, σ=identity; bias=true, init=glorot_uniform)

Dense(W::AbstractMatrix, [bias, σ])

Create a traditional fully connected layer, whose forward pass is given by:

y = σ.(W * x .+ bias)

The input `x` should be a vector of length `in`, or batch of vectors represented

as an `in × N` matrix, or any array with `size(x,1) == in`.

The out `y` will be a vector  of length `out`, or a batch with

`size(y) == (out, size(x)[2:end]...)`

Keyword `bias=false` will switch off trainable bias for the layer.

The initialisation of the weight matrix is `W = init(out, in)`, calling the function

given to keyword `init`, with default [`glorot_uniform`](@ref Flux.glorot_uniform).

The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.

# Examples

```jldoctest

julia> d = Dense(5 => 2)

Dense(5 => 2)       # 12 parameters

julia> d(rand32(5, 64)) |> size

(2, 64)

julia> d(rand32(5, 6, 4, 64)) |> size  # treated as three batch dimensions

(2, 6, 4, 64)

julia> d1 = Dense(ones(2, 5), false, tanh)  # using provided weight matrix

Dense(5 => 2, tanh; bias=false)  # 10 parameters

julia> d1(ones(5))

2-element Vector{Float64}:

0.9999092042625951

0.9999092042625951

julia> Flux.params(d1)  # no trainable bias

Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])

```

"""

struct Dense{F, M<:AbstractMatrix, B}
  weight::M
  bias::B
  σ::F
  function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
    b = create_bias(W, bias, size(W, 1))
    new{F, M, typeof(b)}(W, b, σ)
  end
end

function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
               init = glorot_uniform, bias = true)
  Dense(init(out, in), bias, σ)
end

@functor Dense

function (a::Dense)(x::AbstractVecOrMat)
  _size_check(a, x, 1 => size(a.weight, 2))
  σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
  xT = _match_eltype(a, x)  # fixes Float64 input, etc.
  return σ.(a.weight * xT .+ a.bias)
end

function (a::Dense)(x::AbstractArray)
  _size_check(a, x, 1 => size(a.weight, 2))
  reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
end
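# Illustration (not part of the original source): the AbstractArray method above treats
# every dimension after the first as a batch dimension. For a hypothetical layer
#     d = Dense(5 => 2); x = rand(Float32, 5, 6, 4)
# the call d(x) flattens x to a 5×24 matrix, applies the matrix rule, and reshapes the
# result back, so size(d(x)) == (2, 6, 4).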

function Base.show(io::IO, l::Dense)
  print(io, "Dense(", size(l.weight, 2), " => ", size(l.weight, 1))
  l.σ == identity || print(io, ", ", l.σ)
  l.bias == false && print(io, "; bias=false")
  print(io, ")")
end

Dense(W::LinearAlgebra.Diagonal, bias = true, σ = identity) =
  Scale(W.diag, bias, σ)
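# Illustration (not part of the original source): Dense(LinearAlgebra.Diagonal([1, 2, 3]))
# therefore returns Scale([1, 2, 3], true, identity), an element-wise layer, not a Dense.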

function _size_check(layer, x::AbstractArray, (d, n)::Pair)
  d > 0 || throw(DimensionMismatch(string(
    "layer ", layer, " expects ndims(input) > ", ndims(x) - d, ", but got ", summary(x))))
  size(x, d) == n || throw(DimensionMismatch(string(
    "layer ", layer, " expects size(input, $d) == $n, but got ", summary(x))))
end
ChainRulesCore.@non_differentiable _size_check(::Any...)

"""

Scale(size::Integer..., σ=identity; bias=true, init=ones32)

Scale(scale::AbstractArray, [bias, σ])

Create an element-wise layer, whose forward pass is given by:

y = σ.(scale .* x .+ bias)

This uses `.*` instead of matrix multiplication `*` of [`Dense`](@ref).

The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,

with `init=ones32` by default. You may specify the function `init`,

turn off trainable bias with `bias=false`, or provide the array(s) explicitly.

Used by [`LayerNorm`](@ref) with `affine=true`.

# Examples

```jldoctest

julia> a = Flux.Scale(2)

Scale(2)            # 4 parameters

julia> Flux.params(a)

Params([Float32[1.0, 1.0], Float32[0.0, 0.0]])

julia> a([1 2 3])

2×3 Matrix{Float32}:

1.0  2.0  3.0

1.0  2.0  3.0

julia> b = Flux.Scale([1 2 3 4], false, abs2)

Scale(1, 4, abs2; bias=false)  # 4 parameters

julia> b([1, 10])

2×4 Matrix{Int64}:

1    4    9    16

100  400  900  1600

julia> Flux.params(b)

Params([[1 2 3 4]])

```

"""

struct Scale{F, A<:AbstractArray, B}
  scale::A
  bias::B
  σ::F
  function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
    b = create_bias(scale, bias, size(scale)...)
    new{F, A, typeof(b)}(scale, b, σ)
  end
end

Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) =
  Scale(init(s1, s23...), bias, _act)
Scale(size_act...; bias = true, init = ones32) =
  Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])

@functor Scale
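# Illustration (not part of the original source): Scale(2, 3, tanh) hits the splatted
# constructor above, which peels the trailing activation off `size_act` and forwards
# Scale(2, 3; _act = tanh), i.e. Scale(ones32(2, 3), true, tanh).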

function (a::Scale)(x::AbstractArray)
  σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
  σ.(a.scale .* x .+ a.bias)
end

function Base.show(io::IO, l::Scale)
  print(io, "Scale(", join(size(l.scale), ", "))
  l.σ == identity || print(io, ", ", l.σ)
  l.bias == false && print(io, "; bias=false")
  print(io, ")")
end

"""

Maxout(layers...)

Maxout(f, n_alts)

This contains a number of internal layers, each of which receives the same input.

Its output is the elementwise maximum of the the internal layers' outputs.

Instead of defining layers individually, you can provide a zero-argument function

which constructs them, and the number to construct.

Maxout over linear dense layers satisfies the univeral approximation theorem.

See Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks"

[https://arxiv.org/abs/1302.4389](https://arxiv.org/abs/1302.4389).

# Examples

```jldoctest

julia> m = Maxout(x -> abs2.(x), x -> x .* 3);

julia> m([-2 -1 0 1 2])

1×5 Matrix{Int64}:

4  1  0  3  6

julia> m3 = Maxout(() -> Dense(5 => 7, tanh), 3)

Maxout(

Dense(5 => 7, tanh),                  # 42 parameters

Dense(5 => 7, tanh),                  # 42 parameters

Dense(5 => 7, tanh),                  # 42 parameters

)                   # Total: 6 arrays, 126 parameters, 888 bytes.

julia> Flux.outputsize(m3, (5, 11))

(7, 11)

```

"""

struct Maxout{T<:Tuple}
  layers::T
end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

@functor Maxout

function (mo::Maxout)(input::AbstractArray)
  # Perhaps surprisingly, pairwise max broadcast is often faster,
  # even with Zygote. See #698 and #1794
  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.layers)
end
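# Illustration (not part of the original source): for Maxout(f, g, h) this mapreduce is
# effectively max.(max.(f(x), g(x)), h(x)), the elementwise maximum over all sub-layer outputs.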

function Base.show(io::IO, mo::Maxout)
  print(io, "Maxout(")
  _show_layers(io, mo.layers)
  print(io, ")")
end

"""

SkipConnection(layer, connection)

Create a skip connection which consists of a layer or `Chain` of consecutive

layers and a shortcut connection linking the block's input to the output

through a user-supplied 2-argument callable. The first argument to the callable

will be propagated through the given `layer` while the second is the unchanged,

"skipped" input.

The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`.

Here is a more complicated example:

```jldoctest

julia> m = Conv((3,3), 4 => 7, pad=(1,1));

julia> x = ones(Float32, 5, 5, 4, 10);

julia> size(m(x)) == (5, 5, 7, 10)

true

julia> sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3));

julia> size(sm(x)) == (5, 5, 11, 10)

true

```

"""

struct SkipConnection{T,F}
  layers::T
  connection::F  # user can pass arbitrary connections here, such as (a,b) -> a + b
end

@functor SkipConnection

function (skip::SkipConnection)(input)
  skip.connection(skip.layers(input), input)
end
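# Illustration (not part of the original source): SkipConnection(layer, +)(x) therefore
# computes layer(x) + x, the classic residual connection mentioned in the docstring.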

function Base.show(io::IO, b::SkipConnection)
  print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
end

"""

Bilinear((in1, in2) => out, σ=identity; bias=true, init=glorot_uniform)

Bilinear(W::AbstractArray, [bias, σ])

Creates a layer which is fully connected between two inputs and the output, and otherwise similar to [`Dense`](@ref).

Its output, given vectors `x` & `y`, is another vector `z` with,

for all `i ∈ 1:out`:

z[i] = σ(x' * W[i,:,:] * y + bias[i])

If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form,

with `B` the Bilinear layer.

If the second input `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`

The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,

which is accepted as the input to a `Chain`.

If the two input sizes are the same, `in1 == in2`, then you may write `Bilinear(in => out, σ)`.

The initialisation works as for [`Dense`](@ref) layer, with `W = init(out, in1, in2)`.

By default the bias vector is `zeros(Float32, out)`, option `bias=false` will switch off

trainable bias. Either of these may be provided explicitly.

# Examples

```jldoctest

julia> x, y = randn(Float32, 5, 32), randn(Float32, 5, 32);

julia> B = Flux.Bilinear((5, 5) => 7)

Bilinear(5 => 7)    # 182 parameters

julia> B(x) |> size  # interactions based on one input

(7, 32)

julia> B(x,y) == B((x,y))  # two inputs, may be given as a tuple

true

julia> sc = SkipConnection(

Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)),

Flux.Bilinear((9, 5) => 3, bias=false),

);  # used as the recombinator, with skip as the second input

julia> sc(x) |> size

(3, 32)

julia> Flux.Bilinear(rand(4,8,16), false, tanh)  # first dim of weight is the output

Bilinear((8, 16) => 4, tanh; bias=false)  # 512 parameters

```

"""

struct Bilinear{F,A,B}
  weight::A
  bias::B
  σ::F
  function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
    ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
    b = create_bias(W, bias, size(W, 1))
    new{F, A, typeof(b)}(W, b, σ)
  end
end

@functor Bilinear

function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity;
                  bias = true, init = glorot_uniform)
  Bilinear(init(out, in1, in2), bias, σ)
end
Bilinear((in12, out)::Pair{<:Integer, <:Integer}, σ = identity; kw...) =
  Bilinear((in12, in12) => out, σ; kw...)

function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
  W, b, σ = a.weight, a.bias, a.σ

  d_z, d_x, d_y = size(W)
  d_x == size(x, 1) && d_y == size(y, 1) || throw(DimensionMismatch("number of rows in data must match W"))
  size(x, 2) == size(y, 2) || throw(DimensionMismatch("Data inputs must agree on number of columns, got $(size(x, 2)) and $(size(y, 2))"))

  # @einsum Wy[o,i,s] := W[o,i,j] * y[j,s]
  Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :))

  # @einsum Z[o,s] := Wy[o,i,s] * x[i,s]
  Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :)))
  Z = reshape(Wyx, (d_z, :))

  # @einsum out[o,s] := σ(Z[o,i] + b[o])
  σ.(Z .+ b)
end
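# Reference semantics (illustration only, not part of the source): for each output o and
# batch column s, the batched code above computes
#     Z[o, s] = sum(W[o, i, j] * x[i, s] * y[j, s] for i in 1:d_x, j in 1:d_y)
# before broadcasting the bias and activation.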

(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
(a::Bilinear)(x::AbstractVector, y::AbstractVector) = vec(a(reshape(x, :, 1), reshape(y, :, 1)))
(a::Bilinear)(x::NTuple{2, AbstractArray}) = a(x[1], x[2])

function Base.show(io::IO, l::Bilinear)
  if size(l.weight, 2) == size(l.weight, 3)
    print(io, "Bilinear(", size(l.weight, 2), " => ", size(l.weight, 1))
  else
    print(io, "Bilinear((", size(l.weight, 2), ", ", size(l.weight, 3), ") => ", size(l.weight, 1))
  end
  l.σ == identity || print(io, ", ", l.σ)
  l.bias === false && print(io, "; bias=false")
  print(io, ")")
end

"""

Parallel(connection, layers...)

Parallel(connection; name = layer, ...)

Create a layer which passes an input array to each path in

`layers`, before reducing the output with `connection`.

Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.

If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.

Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.

These can be accessed by indexing: `m[1] == m[:name]` is the first layer.

and [`Maxout`](@ref) which reduces by broadcasting `max`.

# Examples

```jldoctest

julia> model = Chain(Dense(3 => 5),

Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),

Dense(8 => 17));

julia> model(rand32(3)) |> size

(17,)

julia> model2 = Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2))

Parallel(

+,

α = Dense(10 => 2, tanh),             # 22 parameters

β = Dense(5 => 2),                    # 12 parameters

)                   # Total: 4 arrays, 34 parameters, 392 bytes.

julia> model2(rand32(10), rand32(5)) |> size

(2,)

julia> model2[:α](rand32(10)) |> size

(2,)

julia> model2[:β] == model2[2]

true

```

"""

struct Parallel{F, T<:Union{Tuple, NamedTuple}}
  connection::F
  layers::T
end

Parallel(connection, layers...) = Parallel(connection, layers)
function Parallel(connection; kw...)
  layers = NamedTuple(kw)
  if :layers in keys(layers) || :connection in keys(layers)
    throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
  end
  isempty(layers) && return Parallel(connection, ())
  Parallel(connection, layers)
end

@functor Parallel

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)

function _parallel_check(layers, xs)
  nl = length(layers)
  nx = length(xs)
  if (nl != nx)
    throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
  end
end
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)

function (m::Parallel)(xs...)
  _parallel_check(m.layers, xs)
  m.connection(map(|>, xs, Tuple(m.layers))...)
end
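# Illustration (not part of the original source): `map(|>, xs, Tuple(m.layers))` pipes the
# i-th input into the i-th layer, so Parallel(+, f, g)(x, y) evaluates f(x) + g(y).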

Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
  Parallel(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))

Base.keys(m::Parallel) = keys(getfield(m, :layers))

function Base.show(io::IO, m::Parallel)
  print(io, "Parallel(", m.connection, ", ")
  _show_layers(io, m.layers)
  print(io, ")")
end

"""

PairwiseFusion(connection, layers...)

## Arguments

- `connection`: A function taking 2 inputs and combining them into a single output

- `layers`: The layers whose outputs are combined

## Inputs

This layer behaves differently based on input type:

1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,

then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.

Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))`

may be drawn as:

```

x1 → layer1 → y1 ↘

connection → layer2 → y2 ↘

x2 ↗                          connection → layer3 → y3

x3 ↗

```

... or written as:

```julia

y1 = layer1(x1)

y2 = layer2(connection(x2, y1))

y3 = layer3(connection(x3, y2))

```

2. With just one input, each layer receives the same `x` combined with the previous output.

Thus `y = PairwiseFusion(connection, layers...)(x)` obeys:

```julia

y[1] == layers[1](x)

for i in 2:length(layers)

y[i] == connection(x, layers[i](y[i-1]))

end

```

## Returns

A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).

"""

struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
  connection::F
  layers::T
end

PairwiseFusion(connection, layers...) = PairwiseFusion(connection, layers)
function PairwiseFusion(connection; kw...)
  layers = NamedTuple(kw)
  if :layers in keys(layers) || :connection in keys(layers)
    throw(ArgumentError("a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`"))
  end
  isempty(layers) && return PairwiseFusion(connection, ())
  PairwiseFusion(connection, layers)
end
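# Usage sketch (not from the original source; sizes are arbitrary):
#     m = PairwiseFusion(+, Dense(3 => 3), Dense(3 => 3))
#     y = m((rand32(3), rand32(3)))   # Tuple input, one element per sub-layer
#     length(y) == 2                  # returns (y1, y2) as described in the docstring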

function _pairwise_check(x, layers, T)
  lx = length(x)
  N = length(layers)
  if T <: Tuple && lx != N
    throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
  end
end
ChainRulesCore.@non_differentiable _pairwise_check(lx, N, T)

function (m::PairwiseFusion)(x::T) where {T}
  _pairwise_check(x, m.layers, T)
  applypairwisefusion(m.layers, m.connection, x)
end
(m::PairwiseFusion)(xs...) = m(xs)

@generated function applypairwisefusion(layers::Tuple{Vararg{Any,N}}, connection, x::T) where {N, T}
  y_symbols = [gensym() for _ in 1:(N + 1)]
  getinput(i) = T <: Tuple ? :(x[$i]) : :x
  calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
  for i in 1:N - 1
    push!(calls, quote
      $(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]))
      $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))
    end)
  end
  push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
  push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
  return Expr(:block, calls...)
end
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)
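# Illustration (not part of the original source): for N = 2 and a Tuple input, the
# generated body above is roughly
#     t = x[1]; y1 = layers[1](t); t = connection(y1, x[2]); y2 = layers[2](t)
# returning (y1, y2).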

@functor PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
  PairwiseFusion(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))

Base.keys(m::PairwiseFusion) = keys(getfield(m, :layers))

function Base.show(io::IO, m::PairwiseFusion)
  print(io, "PairwiseFusion(", m.connection, ", ")
  _show_layers(io, m.layers)
  print(io, ")")
end

"""

Embedding(in => out; init=randn32)

A lookup table that stores embeddings of dimension `out`

for a vocabulary of size `in`, as a trainable matrix.

This layer is often used to store word embeddings and retrieve them using indices.

The input to the layer can be a vocabulary index in `1:in`, an array of indices,

or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).

For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.

For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`.

# Examples

```jldoctest

julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))

Embedding(26 => 4)  # 104 parameters

julia> emb(2)  # one column of e.weight (here not random!)

4-element Vector{Float32}:

0.0

22.0

0.0

0.0

julia> emb([3, 1, 20, 14, 4, 15, 7])  # vocabulary indices, in 1:26

4×7 Matrix{Float32}:

0.0  22.0  0.0  0.0   0.0  0.0  0.0

0.0   0.0  0.0  0.0   0.0  0.0  0.0

22.0   0.0  0.0  0.0   0.0  0.0  0.0

0.0   0.0  0.0  0.0  22.0  0.0  0.0

julia> ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))

true

julia> emb(rand(1:26, (10, 1, 12))) |> size  # three batch dimensions

(4, 10, 1, 12)

```

"""

struct Embedding{W<:AbstractMatrix}
  weight::W
end

@functor Embedding

Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Integer) = m.weight[:, x]
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)

(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x
``````

usually a `OneHotVector`

``````

(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x
``````

usually a `OneHotMatrix`
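As a quick check (a sketch, not from the source), the one-hot matrix path is simply column selection from the weight matrix, so it agrees with the integer-index path:

```julia
using Flux

emb = Embedding(26 => 4)
oh = Flux.onehotbatch([3, 1], 1:26)        # 26×2 one-hot matrix
emb(oh) == emb.weight * oh == emb([3, 1])  # true
```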

``````

(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)

function Base.show(io::IO, m::Embedding)
  print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end

"""

_splitat(data::AbstractVector, at::AbstractVector{Int})

Partitions `data` into a vector of views.

Each index `i in at` specifies that a view starts with `data[i]`.

These indices must be strictly increasing, and start at `1`.

The resulting views do not overlap, and are never empty.

The last view always ends with `data[end]`.

### Example

```jldoctest

julia> Flux._splitat(collect('A':'Z'), [1, 3, 4, 13])

4-element Vector{SubArray{Char, 1, Vector{Char}, Tuple{UnitRange{Int64}}, true}}:

['A', 'B']

['C']

['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']

['M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

```

"""

function _splitat(data::AbstractVector, at::AbstractVector{<:Integer})
  at[begin] == firstindex(data) || throw(ArgumentError("The first element in `at` must be 1."))
  at[end] <= lastindex(data) || throw(ArgumentError("The last element in `at` must be at most the length of `data`."))
  issorted(at, lt = <=) || throw(ArgumentError("`at` must be monotonically increasing with no duplicates."))
  iplus = vcat(at, lastindex(data) + 1)
  return [view(data, iplus[n]:(iplus[n+1] - 1)) for n in eachindex(at)]
end

"""

EmbeddingBag(in => out, reduction=mean; init=Flux.randn32)

A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`.

Differs from [`Embedding`](@ref) in that, instead of acting on a single vocabulary index,

it always acts a vector of indices which it calls a "bag".

Their individual embedding vectors are reduced to one, using `mean` or some other function.

Instead of acting on one "bag", such as `x::Vector{Int}`, the layer can also act on several:

* Acting on a vector of "bags", it produces a matrix whose columns are the reduced vectors.

More generally on `x::Array{Vector{Int}}`, its output is of size `(out, size(x)...)`.

* Any higher-rank array of integers is interpreted as a collection of "bags" each along the first dimension.

Thus the output is `mapslices(e, x; dims=1)` when `e::EmbeddingBag` and `x::Array{Int,N}`.

This method is more efficient, but requires that all "bags" have the same length.

* A vector of "bags" may also be produced by splitting a vector of indices at specified points.

For this case the layer takes two inputs, both vectors of integers. See details below.

The "bag" may equivalently be represented as a `OneHotMatrix`. A collection of these,

or one higher-rank `OneHotArray`, again produce a stack of embeddings. See details below.

# Examples

```jldoctest

julia> vocab_size = 26;  # embed into 3 dimensions, with non-random vectors:

julia> eb = EmbeddingBag(vocab_size => 3, init=Flux.identity_init(gain=100))

EmbeddingBag(26 => 3)  # 78 parameters

julia> eb([2])  # one bag of 1 item

3-element Vector{Float32}:

0.0

100.0

0.0

julia> eb([3,3,1])  # one bag of 3 items, one mean embedding

3-element Vector{Float32}:

33.333332

0.0

66.666664

julia> eb([[3,1,3], [2,1]])  # two bags

3×2 Matrix{Float32}:

33.3333  50.0

0.0     50.0

66.6667   0.0

julia> eb([1 1 1 1; 1 2 3 4])  # 4 bags each of 2 items, eachcol([1 1 1 1; 1 2 3 4])

3×4 Matrix{Float32}:

100.0  50.0  50.0  50.0

0.0  50.0   0.0   0.0

0.0   0.0  50.0   0.0

julia> eb(rand(1:26, 10, 5, 5)) |> size  # 25 bags each of 10 items

(3, 5, 5)

```

Another way to specify "many bags of many items" is to provide a vector `data` (each in `1:in`)

and a vector `at` stating where to split that up into "bags".

The first bag starts with `data[at[1]]`, the second at `data[at[2]]`, and so on,

with no overlaps and nothing left out (thus it requires `at[1]==1`).

```jldoctest

julia> data = [11, 1, 12, 2, 13, 3, 14];

julia> Flux._splitat(data, [1, 4]) |> println  # internal function, makes data[1:3], data[4:end]

[[11, 1, 12], [2, 13, 3, 14]]

julia> eb(data, [1, 4])  # two bags, of 3 and 4 items

3×2 Matrix{Float32}:

33.3333   0.0

0.0     25.0

0.0     25.0

```

Finally, each bag may also be also be represented as a [`OneHotMatrix`](@ref OneHotArrays.onehotbatch).

```jldoctest

julia> eb(Flux.onehotbatch("bba", 'a':'z'))  # same as [2,2,1], one bag of 3 items

3-element Vector{Float32}:

33.333332

66.666664

0.0

julia> eb([Flux.onehotbatch("bba", 'a':'z'), Flux.onehotbatch("cc", 'a':'z')])  # two bags

3×2 Matrix{Float32}:

33.3333    0.0

66.6667    0.0

0.0     100.0

```

"""

struct EmbeddingBag{F, W<:AbstractMatrix}
  weight::W
  reduction::F
end

@functor EmbeddingBag

EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) =
  EmbeddingBag(init(out, in), reduction)
EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean)

(m::EmbeddingBag)(data::AbstractVector, at::AbstractVector) = m(_splitat(data, at))

(m::EmbeddingBag)(inds::AbstractArray{<:Integer}) = dropdims(m.reduction(Embedding(m.weight)(inds), dims=2), dims=2)
(m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")

(m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2)
(m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector")
``````

These two could be `stack(m, bags)`, but no AD support yet. (Gradient for weight quite inefficient here.)
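For reference (a sketch, not from the source), the vector-of-bags method defined below reduces each bag separately and concatenates the resulting columns:

```julia
using Flux

eb = EmbeddingBag(26 => 3)
bags = [[1, 2], [5, 6, 7]]          # two bags of different lengths
size(reduce(hcat, eb.(bags)))       # (3, 2): one reduced column per bag, same as eb(bags)
```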

``````

(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
(m::EmbeddingBag)(bags::AbstractArray{<:AbstractVector}) = reshape(m(vec(bags)), :, size(bags)...)

(m::EmbeddingBag)(bags::AbstractArray{<:AbstractMatrix{Bool}}) = reshape(reduce(hcat, m.(vec(bags))), :, size(bags)...)

function Base.show(io::IO, m::EmbeddingBag)
  print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
``````