Neural Network primitives from NNlib.jl
Flux re-exports all of the functions exported by the NNlib package. This includes activation functions, described on their own page. Many of the functions on this page exist primarily as the internal implementation of Flux layers, but can also be used independently.
Attention
Primitives for the MultiHeadAttention layer.
NNlib.dot_product_attention (Function)
dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])
Multihead dot product attention used in transformer architectures.
The input arrays must have the first two dimensions given by the number of features and the sequence length, then an arbitrary number of batch dimensions or none.
Returns the attention output array of size (v_dim, q_len, batch_size...) and the attention scores of size (kv_len, q_len, nheads, batch_size...).
See also dot_product_attention_scores if you only need the attention scores.
Arguments
query: Query array of size (qk_dim, q_len, batch_size...).
key: Key array of size (qk_dim, kv_len, batch_size...).
value: Value array of size (v_dim, kv_len, batch_size...).
bias: Either nothing or an array broadcastable to size (kv_len, q_len, nheads, batch_size). It will be added to the attention scores before applying the softmax. Default nothing.
fdrop: A dropout function or layer to be applied on the attention scores right after the softmax. Default identity (no dropout).
mask: Either nothing or a boolean array broadcastable to size (kv_len, q_len, nheads, batch_size). The mask is applied to the attention scores just before the softmax. See make_causal_mask for creating causal masks. Default nothing.
nheads: Number of heads to split the input arrays into. Default 1.
Examples
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
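The following plain-Julia reference sketch shows what the computation does for one head (nheads = 1) and one batch element, with no bias, mask, or dropout. It assumes the usual 1/sqrt(qk_dim) scaling of the scores and a softmax over the key dimension. `naive_attention` is a hypothetical helper for illustration, not NNlib's implementation.

```julia
# Hypothetical reference sketch, not NNlib's implementation.
# q: (qk_dim, q_len), k: (qk_dim, kv_len), v: (v_dim, kv_len); one head, one batch element.
function naive_attention(q::AbstractMatrix, k::AbstractMatrix, v::AbstractMatrix)
    scores = (k' * q) ./ sqrt(size(q, 1))          # (kv_len, q_len), assumed 1/sqrt(qk_dim) scaling
    e = exp.(scores .- maximum(scores; dims = 1))   # shift by the column maximum for stability
    α = e ./ sum(e; dims = 1)                       # softmax over the key dimension
    y = v * α                                       # (v_dim, q_len): weighted sum of value columns
    return y, α
end

q, k, v = rand(10, 20), rand(10, 30), rand(20, 30)
y, α = naive_attention(q, k, v)
```

The shapes match the documented ones with the head and batch dimensions dropped: y is (v_dim, q_len) and α is (kv_len, q_len).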
NNlib.dot_product_attention_scores (Function)
dot_product_attention_scores(query, key, [bias]; [fdrop, mask])
Return the attention scores for the dot_product_attention. Input arrays must have dimensions (num_features ÷ nheads, nheads, sequence_length, batch_size).
See dot_product_attention for more details.
NNlib.make_causal_mask (Function)
make_causal_mask(x, dims=2)
Return a boolean square matrix m of the same type as x and of side size(x, dims). Its elements are set such that m[i, j] == i ≤ j.
Can be used to mask the attention scores in dot_product_attention.
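The next sketch applies the same rule in plain Julia. `naive_causal_mask` is a hypothetical helper. It returns an ordinary Matrix{Bool} rather than an array of the same type as x.

```julia
# Hypothetical helper: Bool matrix following the rule m[i, j] == (i <= j).
naive_causal_mask(n::Integer) = [i <= j for i in 1:n, j in 1:n]

m = naive_causal_mask(4)
# Column j (a query position) is true only for rows i <= j (key positions).
# With scores laid out as (kv_len, q_len), each query can only see keys up to its own position.
```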
Softmax
Flux's Flux.logitcrossentropy uses NNlib.logsoftmax internally.
NNlib.softmax (Function)
softmax(x; dims = 1)
Softmax turns input array x into probability distributions that sum to 1 along the dimensions specified by dims. It is semantically equivalent to the following:
softmax(x; dims = 1) = exp.(x) ./ sum(exp.(x), dims = dims)
with additional manipulations enhancing numerical stability.
For a matrix input x it will by default (dims = 1) treat it as a batch of vectors, with each column independent. Keyword dims = 2 will instead treat rows independently, and so on.
See also logsoftmax.
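A standard way to get numerical stability is to subtract the maximum along dims before exponentiating. This leaves the result mathematically unchanged but prevents overflow. The plain-Julia sketch below assumes that trick. `stable_softmax` is a hypothetical name, not NNlib's code.

```julia
# Hypothetical sketch of a numerically stable softmax (not NNlib's implementation).
function stable_softmax(x; dims = 1)
    # exp(-max) cancels between numerator and denominator, so the result is unchanged,
    # but exp never sees large positive arguments.
    z = exp.(x .- maximum(x; dims = dims))
    return z ./ sum(z; dims = dims)
end

naive = exp.([1000.0, 1001.0]) ./ sum(exp.([1000.0, 1001.0]))   # Inf / Inf gives NaN
stable = stable_softmax([1000.0, 1001.0])                       # finite probabilities
```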
Examples
julia> softmax([1, 2, 3])
3-element Vector{Float64}:
0.09003057317038046
0.24472847105479764
0.6652409557748218
julia> softmax([1 2 3; 2 2 2]) # dims=1
2×3 Matrix{Float64}:
0.268941 0.5 0.731059
0.731059 0.5 0.268941
julia> softmax([1 2 3; 2 2 2]; dims=2)
2×3 Matrix{Float64}:
0.0900306 0.244728 0.665241
0.333333 0.333333 0.333333
Note that, when used with Flux.jl, softmax must not be passed to layers like Dense which accept an activation function. The activation is broadcast over the result, and thus applies to individual numbers, but softmax always needs to see the whole column.
julia> using Flux
julia> x = randn(Float32, 4, 4, 3, 13);
julia> model = Chain(Conv((4, 4), 3 => 8, tanh), Flux.flatten, Dense(8 => 7), softmax);
julia> model(x) |> size
(7, 13)
julia> Dense(4 => 7, softmax)(x)
ERROR: `softmax(x)` called with a number, but it expects an array.
NNlib.logsoftmax (Function)
logsoftmax(x; dims = 1)
Computes the log of softmax in a more numerically stable way than directly taking log.(softmax(x)). Commonly used in computing cross entropy loss.
It is semantically equivalent to the following:
logsoftmax(x; dims = 1) = x .- log.(sum(exp.(x), dims = dims))
See also softmax.
Pooling
Flux's AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, and MeanPool use NNlib.PoolDims, NNlib.maxpool, and NNlib.meanpool as their backend.
NNlib.PoolDims (Type)
PoolDims(x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int};
         stride=k, padding=0, dilation=1) where {M, L}
Dimensions for a "pooling" operation that can have an arbitrary input size, kernel size, stride, dilation, and channel count. Used to dispatch onto efficient implementations at compile-time.
NNlib.lpnormpool (Function)
lpnormpool(x, p::Real, k::NTuple{N, Integer}; pad=0, stride=k)
Perform Lp pool operation with value of the Lp norm p and window size k on input tensor x, also known as LPPool in PyTorch. This pooling operator is from Learned-Norm Pooling for Deep Feedforward and Recurrent Neural Networks.
Arguments:
x and k: Expects ndims(x) ∈ 3:5, and always length(k) == ndims(x) - 2.
p is restricted to 0 < p < Inf.
pad: See pad_zeros for details.
stride: Either a tuple with the same length as k, or one integer for all directions. Default is k.
For all elements x in a size k window, lpnormpool computes (∑ᵢ xᵢ^p)^(1 / p) as an element of the output.
Thus lpnormpool(x, 1, k) ./ prod(k) ≈ meanpool(x, k) and lpnormpool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k).
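The two identities above can be checked with a naive one-dimensional sketch in plain Julia. The sketch works on a plain vector with no channel or batch dimensions, uses stride = k, and applies no padding. The helpers `windows`, `naive_lpnormpool1d` and `naive_meanpool1d` are hypothetical and not part of NNlib.

```julia
# Hypothetical 1-D sketches: window size k, stride k, no padding, no channel/batch dims.
windows(x, k) = [x[(i - 1) * k + 1 : i * k] for i in 1:length(x) ÷ k]

# (∑ᵢ xᵢ^p)^(1/p) over each window
naive_lpnormpool1d(x::AbstractVector, p::Real, k::Integer) =
    [sum(w .^ p)^(1 / p) for w in windows(x, k)]

# arithmetic mean over each window
naive_meanpool1d(x::AbstractVector, k::Integer) = [sum(w) / k for w in windows(x, k)]

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```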
NNlib.maxpool (Function)
maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k)
Perform max pool operation with window size k on input tensor x.
Arguments:
x and k: Expects ndims(x) ∈ 3:5, and always length(k) == ndims(x) - 2.
pad: See pad_zeros for details.
stride: Either a tuple with the same length as k, or one integer for all directions. Default is k.
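Here is a naive two-dimensional sketch of what maxpool computes over the spatial dimensions. It works on a single matrix with no channel or batch dimensions, uses stride equal to the window size, and applies no padding. `naive_maxpool2d` is a hypothetical helper, not NNlib's implementation.

```julia
# Hypothetical 2-D sketch: window k = (k1, k2), stride k, no padding, single channel.
function naive_maxpool2d(x::AbstractMatrix, k::NTuple{2,Int})
    W, H = size(x) .÷ k                         # number of complete windows per dimension
    rows(i) = (i - 1) * k[1] + 1 : i * k[1]     # rows covered by window i
    cols(j) = (j - 1) * k[2] + 1 : j * k[2]     # columns covered by window j
    return [maximum(x[rows(i), cols(j)]) for i in 1:W, j in 1:H]
end

x = reshape(1:16, 4, 4)   # columns 1:4, 5:8, 9:12, 13:16
y = naive_maxpool2d(x, (2, 2))
```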
NNlib.meanpool (Function)
meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k)
Perform mean pool operation with window size k on input tensor x.
Arguments:
x and k: Expects ndims(x) ∈ 3:5, and always length(k) == ndims(x) - 2.
pad: See pad_zeros for details.
stride: Either a tuple with the same length as k, or one integer for all directions. Default is k.
Padding
NNlib.pad_circular (Function)
pad_circular(x, pad::Tuple; [dims])
pad_circular(x, pad::Int; [dims])
Pad the array x "circularly" across the border by wrapping around values from the opposite side of x.
pad can be a tuple of integers (l1, r1, ..., ln, rn) of some length 2n that specifies the left and right padding size for each of the dimensions in dims. If dims is not given, it defaults to the first n dimensions.
If pad is an integer, it is applied on both sides on every dimension in dims. In this case, dims defaults to the first ndims(x)-2 dimensions (i.e. it excludes the channel and batch dimensions).
The pad length on either side in any dimension must not exceed the size of x in that dimension, i.e. pad_circular is not able to create arbitrarily sized tilings of x.
See also pad_repeat, pad_reflect, pad_symmetric, and pad_constant.
julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
1 4 7
2 5 8
3 6 9
julia> pad_circular(r, (1,2,1,2))
6×6 Matrix{Int64}:
9 3 6 9 3 6
7 1 4 7 1 4
8 2 5 8 2 5
9 3 6 9 3 6
7 1 4 7 1 4
8 2 5 8 2 5
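The wrap-around rule amounts to modular indexing. The plain-Julia sketch below reproduces the example above. The helper `naive_pad_circular` is hypothetical. It is restricted to matrices and does not enforce the size limit stated above.

```julia
# Hypothetical sketch: circular padding of a matrix with pad = (l1, r1, l2, r2).
# mod1 wraps an out-of-range index back into 1:m (or 1:n).
function naive_pad_circular(x::AbstractMatrix, pad::NTuple{4,Int})
    l1, r1, l2, r2 = pad
    m, n = size(x)
    return [x[mod1(i - l1, m), mod1(j - l2, n)] for i in 1:m+l1+r1, j in 1:n+l2+r2]
end

y = naive_pad_circular(reshape(1:9, 3, 3), (1, 2, 1, 2))
```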
NNlib.pad_constant (Function)
pad_constant(x, pad::Tuple, val = 0; [dims = :])
pad_constant(x, pad::Int, val = 0; [dims = :])
Pad the array x with the constant value val.
pad can be a tuple of integers. If it has length 2 * length(dims), it specifies the left and right padding size for each of the dimensions in dims as (l1, r1, ..., ln, rn). If it has length length(dims) instead, it applies symmetric padding, i.e. the same amount on both sides of each dimension in dims. If dims is not given, it defaults to all dimensions.
For integer pad input, it is applied on both sides on every dimension in dims.
See also pad_zeros, pad_repeat, pad_reflect, pad_symmetric, and pad_circular.
julia> r = reshape(1:4, 2, 2)
2×2 reshape(::UnitRange{Int64}, 2, 2) with eltype Int64:
1 3
2 4
julia> pad_constant(r, (1, 2, 3, 4), 8)
5×9 Matrix{Int64}:
8 8 8 8 8 8 8 8 8
8 8 8 1 3 8 8 8 8
8 8 8 2 4 8 8 8 8
8 8 8 8 8 8 8 8 8
8 8 8 8 8 8 8 8 8
julia> pad_constant(r, 1, 8)
4×4 Matrix{Int64}:
8 8 8 8
8 1 3 8
8 2 4 8
8 8 8 8
julia> r = reshape(1:27, 3, 3, 3)
3×3×3 reshape(::UnitRange{Int64}, 3, 3, 3) with eltype Int64:
[:, :, 1] =
1 4 7
2 5 8
3 6 9
[:, :, 2] =
10 13 16
11 14 17
12 15 18
[:, :, 3] =
19 22 25
20 23 26
21 24 27
julia> pad_constant(r, (2,1), dims = 1) # asymmetric padding
6×3×3 Array{Int64, 3}:
[:, :, 1] =
0 0 0
0 0 0
1 4 7
2 5 8
3 6 9
0 0 0
[:, :, 2] =
0 0 0
0 0 0
10 13 16
11 14 17
12 15 18
0 0 0
[:, :, 3] =
0 0 0
0 0 0
19 22 25
20 23 26
21 24 27
0 0 0
julia> pad_constant(r, (2,1, 3), dims = (1,2)) # padding must always be either the same length as dims, or double it
ERROR: ArgumentError: Could not parse padding (2, 1, 3) and dims (1, 2)
Stacktrace:
[...]
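A plain-Julia sketch of the same operation for a matrix with a four-element pad tuple: fill an enlarged array with val and copy x into its interior. `naive_pad_constant` is a hypothetical helper, not NNlib's implementation.

```julia
# Hypothetical sketch: constant padding of a matrix with pad = (l1, r1, l2, r2).
function naive_pad_constant(x::AbstractMatrix{T}, pad::NTuple{4,Int}, val = 0) where {T}
    l1, r1, l2, r2 = pad
    m, n = size(x)
    y = fill(convert(T, val), m + l1 + r1, n + l2 + r2)   # start from an all-`val` array
    y[l1+1:l1+m, l2+1:l2+n] .= x                           # copy x into the interior
    return y
end

y = naive_pad_constant(reshape(1:4, 2, 2), (1, 2, 3, 4), 8)
```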
NNlib.pad_reflect (Function)
pad_reflect(x, pad::Tuple; [dims])
pad_reflect(x, pad::Int; [dims])
Pad the array x by reflecting its values across the border.
pad can be a tuple of integers (l1, r1, ..., ln, rn) of some length 2n that specifies the left and right padding size for each of the dimensions in dims. If dims is not given, it defaults to the first n dimensions.
If pad is an integer, it is applied on both sides on every dimension in dims. In this case, dims defaults to the first ndims(x)-2 dimensions (i.e. it excludes the channel and batch dimensions).
See also pad_repeat, pad_symmetric, pad_circular, and pad_constant.
julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
1 4 7
2 5 8
3 6 9
julia> pad_reflect(r, (1,2,1,2))
6×6 Matrix{Int64}:
5 2 5 8 5 2
4 1 4 7 4 1
5 2 5 8 5 2
6 3 6 9 6 3
5 2 5 8 5 2
4 1 4 7 4 1
NNlib.pad_repeat (Function)
pad_repeat(x, pad::Tuple; [dims])
pad_repeat(x, pad::Int; [dims])
Pad the array x by repeating the values on the border.
pad can be a tuple of integers (l1, r1, ..., ln, rn) of some length 2n that specifies the left and right padding size for each of the dimensions in dims. If dims is not given, it defaults to the first n dimensions.
If pad is an integer, it is applied on both sides on every dimension in dims. In this case, dims defaults to the first ndims(x)-2 dimensions (i.e. it excludes the channel and batch dimensions).
See also pad_reflect, pad_symmetric, pad_circular, and pad_constant.
julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
1 4 7
2 5 8
3 6 9
julia> pad_repeat(r, (1,2,3,4))
6×10 Matrix{Int64}:
1 1 1 1 4 7 7 7 7 7
1 1 1 1 4 7 7 7 7 7
2 2 2 2 5 8 8 8 8 8
3 3 3 3 6 9 9 9 9 9
3 3 3 3 6 9 9 9 9 9
3 3 3 3 6 9 9 9 9 9
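Repeating the border values amounts to clamping out-of-range indices to the nearest valid index. A plain-Julia sketch for matrices that reproduces the example above (`naive_pad_repeat` is a hypothetical helper):

```julia
# Hypothetical sketch: "repeat" padding of a matrix with pad = (l1, r1, l2, r2).
# clamp maps every out-of-range index to the nearest border index.
function naive_pad_repeat(x::AbstractMatrix, pad::NTuple{4,Int})
    l1, r1, l2, r2 = pad
    m, n = size(x)
    return [x[clamp(i - l1, 1, m), clamp(j - l2, 1, n)] for i in 1:m+l1+r1, j in 1:n+l2+r2]
end

y = naive_pad_repeat(reshape(1:9, 3, 3), (1, 2, 3, 4))
```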
NNlib.pad_symmetric (Function)
pad_symmetric(x, pad::Tuple; [dims])
pad_symmetric(x, pad::Int; [dims])
Pad the array x by reflecting its values symmetrically across the border, i.e. the border values of x are present in the padding values, in contrast to pad_reflect.
pad can be a tuple of integers (l1, r1, ..., ln, rn) of some length 2n that specifies the left and right padding size for each of the dimensions in dims. If dims is not given, it defaults to the first n dimensions.
If pad is an integer, it is applied on both sides on every dimension in dims. In this case, dims defaults to the first ndims(x)-2 dimensions (i.e. it excludes the channel and batch dimensions).
See also pad_repeat, pad_reflect, pad_circular, and pad_constant.
julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
1 4 7
2 5 8
3 6 9
julia> pad_symmetric(r, (1,2,1,2))
6×6 Matrix{Int64}:
1 1 4 7 7 4
1 1 4 7 7 4
2 2 5 8 8 5
3 3 6 9 9 6
3 3 6 9 9 6
2 2 5 8 8 5
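pad_reflect and pad_symmetric differ only in how an out-of-range index is mirrored. Reflection mirrors about the border element itself. Symmetric padding mirrors about the edge just outside the border element, so the border value is repeated. In the plain-Julia sketch below, the helpers are hypothetical and valid only while each pad is smaller than the corresponding size of x. The sketch reproduces both examples.

```julia
# Hypothetical index maps for a dimension of length n (valid for small pads only).
reflect_index(p, n)   = p < 1 ? 2 - p : (p > n ? 2n - p : p)       # border value not repeated
symmetric_index(p, n) = p < 1 ? 1 - p : (p > n ? 2n + 1 - p : p)   # border value repeated

# Pad a matrix with pad = (l1, r1, l2, r2), mapping each output index through `index_map`.
function naive_pad(x::AbstractMatrix, pad::NTuple{4,Int}, index_map)
    l1, r1, l2, r2 = pad
    m, n = size(x)
    return [x[index_map(i - l1, m), index_map(j - l2, n)] for i in 1:m+l1+r1, j in 1:n+l2+r2]
end

r = reshape(1:9, 3, 3)
yr = naive_pad(r, (1, 2, 1, 2), reflect_index)
ys = naive_pad(r, (1, 2, 1, 2), symmetric_index)
```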
NNlib.pad_zeros (Function)
pad_zeros(x, pad::Tuple; [dims])
pad_zeros(x, pad::Int; [dims])
Pad the array x with zeros. Equivalent to pad_constant with the constant equal to 0.
Convolution
Flux's Conv and CrossCor layers use NNlib.DenseConvDims and NNlib.conv internally.
NNlib.conv (Function)
conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)
Apply convolution filter w to input x. x and w are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively. x and w may have real or complex element types.
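A naive "valid" one-dimensional sketch on plain vectors shows what the flipped keyword changes. It uses no channels, batch, padding, stride, or dilation. It assumes NNlib's convention that flipped = false computes a true convolution (the kernel is reversed) and flipped = true computes cross-correlation, which is what Flux's CrossCor relies on. `naive_conv1d` is a hypothetical helper.

```julia
# Hypothetical 1-D sketch: stride 1, no padding, no dilation, single channel.
function naive_conv1d(x::AbstractVector, w::AbstractVector; flipped::Bool = false)
    k = length(w)
    wk = flipped ? w : reverse(w)   # assumed: flipped = false reverses the kernel (true convolution)
    return [sum(x[i:i+k-1] .* wk) for i in 1:length(x)-k+1]
end

x, w = [1, 2, 3, 4], [1, 0, -1]
```

The output has length(x) - length(w) + 1 elements, because only positions where the kernel fits entirely inside x are kept.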
NNlib.ConvDims (Type)
ConvDims
Type system-level information about convolution dimensions. Critical for things like im2col!() to generate efficient code, and helpful to reduce the number of kwargs getting passed around.
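Much of this dimension information comes down to output sizes. The sketch below gives the usual per-dimension formula for convolution and pooling as a hypothetical plain-Julia helper. NNlib computes the equivalent internally (also for PoolDims), but `conv_output_length` is not part of its API.

```julia
# Hypothetical helper: output length along one spatial dimension of input size n,
# for window size k, padding (lo, hi), stride and dilation.
function conv_output_length(n, k; stride = 1, pad = (0, 0), dilation = 1)
    effective_k = dilation * (k - 1) + 1   # span covered by a dilated kernel
    return div(n + pad[1] + pad[2] - effective_k, stride) + 1
end
```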
NNlib.depthwiseconv (Function)
depthwiseconv(x, w; stride=1, pad=0, dilation=1, flipped=false)
Depthwise convolution operation with filter w on input x. x and w are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
NNlib.DepthwiseConvDims (Type)
DepthwiseConvDims
Concrete subtype of ConvDims for a depthwise convolution. It differs primarily in being characterized by Cin and Cmult rather than Cin and Cout; keeping it separate from DenseConvDims is useful mainly because of these differences in channel calculation.
NNlib.DenseConvDims (Type)
DenseConvDims
Concrete subtype of ConvDims for a normal, dense conv2d/conv3d.
Dropout
NNlib.dropout (Function)
dropout([rng], A, p; [dims])
Returns an array in which each element of A is either replaced with zero, with probability p, or else multiplied by 1/(1-p).
By default every element is treated independently. With keyword dims=1, a choice is made for every value of the 1st index, i.e. each row of a matrix is either zero or not.
The optional first argument is the random number generator to use.
Examples
julia> dropout(ones(2, 10), 0.2)
2×10 Matrix{Float64}:
1.25 1.25 0.0 1.25 1.25 1.25 1.25 1.25 1.25 1.25
1.25 1.25 1.25 0.0 1.25 1.25 0.0 1.25 1.25 1.25
julia> mean(dropout(ones(10^4, 5), 0.2), dims=1)
1×5 Matrix{Float64}:
0.998 1.00075 0.99125 0.99575 1.00075
julia> dropout(ones(5, 5), 0.7, dims=1) # whole row the same
5×5 Matrix{Float64}:
3.33333 3.33333 3.33333 3.33333 3.33333
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
3.33333 3.33333 3.33333 3.33333 3.33333
0.0 0.0 0.0 0.0 0.0
julia> mean(dropout(ones(10^4, 5), 0.3, dims=1), dims=1)
1×5 Matrix{Float64}:
1.00571 1.00571 1.00571 1.00571 1.00571
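The plain-Julia sketch below follows the description above. Each element is kept with probability 1 - p and scaled by 1/(1-p). With dims, the random mask has size 1 along every other dimension and is broadcast along those dimensions. `naive_dropout` is hypothetical and not NNlib's implementation.

```julia
using Random, Statistics

# Hypothetical sketch of (inverted) dropout, not NNlib's implementation.
function naive_dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
    # One random draw per element by default; with `dims`, the mask has size 1 along
    # every dimension not in `dims`, so a single draw is shared along those dimensions.
    msize = dims === (:) ? size(A) : ntuple(d -> d in dims ? size(A, d) : 1, ndims(A))
    keep = rand(rng, msize...) .>= p   # true with probability 1 - p
    return A .* keep ./ (1 - p)        # rescale so the expected value is unchanged
end

y = naive_dropout(MersenneTwister(0), ones(5, 5), 0.7; dims = 1)   # each row is all 0 or all 1/0.3
```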
NNlib.dropout! (Function)
dropout!(B, A, p; [dims])
This does exactly B .= dropout(A, p; dims), or rather, it's the implementation of out-of-place dropout.
Upsampling
Flux's Upsample layer uses NNlib.upsample_nearest, NNlib.upsample_bilinear, and NNlib.upsample_trilinear as its backend. Additionally, Flux's PixelShuffle layer uses NNlib.pixel_shuffle as its backend.
NNlib.upsample_nearest (Function)
upsample_nearest(x, scale::NTuple{S,Int})
upsample_nearest(x; size::NTuple{S,Int})
Upsamples the array x by integer multiples along the first S dimensions. Subsequent dimensions of x are not altered.
Either the scale factors or the final output size can be specified.
See also upsample_bilinear, for two dimensions of an N=4 array.
Example
julia> upsample_nearest([1 2 3; 4 5 6], (2, 3))
4×9 Matrix{Int64}:
1 1 1 2 2 2 3 3 3
1 1 1 2 2 2 3 3 3
4 4 4 5 5 5 6 6 6
4 4 4 5 5 5 6 6 6
julia> ans == upsample_nearest([1 2 3; 4 5 6]; size=(4, 9)) # equivalent
true
julia> upsample_nearest([1 2 3; 4 5 6], (2,))
4×3 Matrix{Int64}:
1 2 3
1 2 3
4 5 6
4 5 6
julia> ans == upsample_nearest([1 2 3; 4 5 6], size=(4,))
true
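With integer factors, nearest-neighbour upsampling repeats every element scale[d] times along dimension d. Base's repeat with the inner keyword computes exactly that, as the plain-Julia sketch below shows. `naive_upsample_nearest` is a hypothetical helper, not NNlib's implementation.

```julia
# Hypothetical sketch: nearest-neighbour upsampling by integer factors along the
# first S dimensions, using Base's repeat with the `inner` keyword.
function naive_upsample_nearest(x::AbstractArray, scale::NTuple{S,Int}) where {S}
    inner = ntuple(d -> d <= S ? scale[d] : 1, ndims(x))   # later dimensions are unchanged
    return repeat(x; inner = inner)
end

x = [1 2 3; 4 5 6]
```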
NNlib.upsample_linear (Function)
upsample_linear(x::AbstractArray{T,3}, scale::Real; align_corners::Bool = true)
upsample_linear(x::AbstractArray{T,3}; size::Integer, align_corners::Bool = true)
Upsamples the first dimension of the array x by the provided scale factor, using linear interpolation. As an alternative to using scale, the resulting array size can be directly specified with a keyword argument.
The size of the output is equal to (scale*S1, S2, S3), where S1, S2, S3 = size(x).
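The sketch below assumes that with align_corners = true, output position i of n_out maps to input coordinate 1 + (i - 1) * (n_in - 1) / (n_out - 1). Under that mapping the first and last samples coincide with the input's end points, which agrees with the upsample_bilinear example on this page. `naive_upsample_linear` is a hypothetical plain-Julia helper on a single vector. NNlib's function operates on 3-d arrays.

```julia
# Hypothetical 1-D sketch of linear upsampling with align_corners = true.
function naive_upsample_linear(v::AbstractVector, n_out::Integer)
    n_in = length(v)
    out = zeros(float(eltype(v)), n_out)
    for i in 1:n_out
        # assumed align_corners mapping: output 1 ↦ input 1, output n_out ↦ input n_in
        t = n_out == 1 ? 1.0 : 1 + (i - 1) * (n_in - 1) / (n_out - 1)
        lo = clamp(floor(Int, t), 1, n_in)
        hi = min(lo + 1, n_in)
        w = t - lo                              # weight of the upper neighbour
        out[i] = (1 - w) * v[lo] + w * v[hi]
    end
    return out
end
```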
NNlib.∇upsample_linear (Function)
∇upsample_linear(Δ::AbstractArray{T,3}; size::Integer, align_corners::Bool = true) where T
Arguments
Δ: Incoming gradient array, backpropagated from downstream layers
size: Size of the image upsampled in the first place
Outputs
dx: Downsampled version of Δ
NNlib.upsample_bilinear (Function)
upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}; align_corners::Bool = true)
upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}, align_corners::Bool = true)
Upsamples the first 2 dimensions of the array x by the upsample factors stored in scale, using bilinear interpolation. As an alternative to using scale, the resulting image size can be directly specified with a keyword argument.
The size of the output is equal to (scale[1]*S1, scale[2]*S2, S3, S4), where S1, S2, S3, S4 = size(x).
Examples
julia> x = reshape(Float32[1 2 3; 4 5 6], (2,3,1,1))
2×3×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
1.0 2.0 3.0
4.0 5.0 6.0
julia> upsample_bilinear(x, (2, 3))
4×9×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
1.0 1.25 1.5 1.75 2.0 2.25 2.5 2.75 3.0
2.0 2.25 2.5 2.75 3.0 3.25 3.5 3.75 4.0
3.0 3.25 3.5 3.75 4.0 4.25 4.5 4.75 5.0
4.0 4.25 4.5 4.75 5.0 5.25 5.5 5.75 6.0
julia> ans == upsample_bilinear(x; size=(4, 9)) # specify output size instead
true
julia> upsample_bilinear(x, (2.5, 3.5)) # non-integer scaling factors are allowed
5×10×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
1.0 1.22222 1.44444 1.66667 1.88889 … 2.33333 2.55556 2.77778 3.0
1.75 1.97222 2.19444 2.41667 2.63889 3.08333 3.30556 3.52778 3.75
2.5 2.72222 2.94444 3.16667 3.38889 3.83333 4.05556 4.27778 4.5
3.25 3.47222 3.69444 3.91667 4.13889 4.58333 4.80556 5.02778 5.25
4.0 4.22222 4.44444 4.66667 4.88889 5.33333 5.55556 5.77778 6.0
NNlib.∇upsample_bilinear (Function)
∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}, align_corners::Bool = true) where T
Arguments
Δ: Incoming gradient array, backpropagated from downstream layers
size: Lateral (W,H) size of the image upsampled in the first place
Outputs
dx: Downsampled version of Δ
NNlib.upsample_trilinear (Function)
upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real}; align_corners::Bool = true)
upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}, align_corners::Bool = true)
Upsamples the first 3 dimensions of the array x by the upsample factors stored in scale, using trilinear interpolation. As an alternative to using scale, the resulting image size can be directly specified with a keyword argument.
The size of the output is equal to (scale[1]*S1, scale[2]*S2, scale[3]*S3, S4, S5), where S1, S2, S3, S4, S5 = size(x).
Examples
using NNlib
x = rand(Float32, 2, 3, 4, 1, 1)  # example input of size (W, H, D, C, N)
upsample_trilinear(x, (2, 3, 4))
upsample_trilinear(x; size=(4, 9, 11))  # specify output size instead
upsample_trilinear(x, (2.5, 3.5, pi))  # non-integer scaling factors are allowed
NNlib.∇upsample_trilinear (Function)
∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}, align_corners::Bool = true) where T
Arguments
Δ: Incoming gradient array, backpropagated from downstream layers
size: Lateral size & depth (W,H,D) of the image upsampled in the first place
Outputs
dx: Downsampled version of Δ
NNlib.pixel_shuffle (Function)
pixel_shuffle(x, r::Integer)
Pixel shuffling operation, upscaling by a factor r.
For 4-arrays representing N images, the operation converts input size(x) == (W, H, r^2*C, N) to output of size (r*W, r*H, C, N). For D-dimensional data, it expects ndims(x) == D+2 with channel and batch dimensions, and divides the number of channels by r^D.
Used in super-resolution networks to upsample towards high resolution features. Reference: Shi et al., "Real-Time Single Image and Video Super-Resolution ...", CVPR 2016, https://arxiv.org/abs/1609.05158
Examples
julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
2×3×4×1 Array{Float64, 4}:
[:, :, 1, 1] =
11.1 12.1 13.1
21.1 22.1 23.1
[:, :, 2, 1] =
11.2 12.2 13.2
21.2 22.2 23.2
[:, :, 3, 1] =
11.3 12.3 13.3
21.3 22.3 23.3
[:, :, 4, 1] =
11.4 12.4 13.4
21.4 22.4 23.4
julia> pixel_shuffle(x, 2) # 4 channels used up as 2x upscaling of image dimensions
4×6×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
11.1 11.3 12.1 12.3 13.1 13.3
11.2 11.4 12.2 12.4 13.2 13.4
21.1 21.3 22.1 22.3 23.1 23.3
21.2 21.4 22.2 22.4 23.2 23.4
julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1]
3×6×1 Array{Float64, 3}:
[:, :, 1] =
1.1 1.2 1.3 1.4 1.5 1.6
2.1 2.2 2.3 2.4 2.5 2.6
3.1 3.2 3.3 3.4 3.5 3.6
julia> pixel_shuffle(y, 2) # 1D image, with 6 channels reduced to 3
6×3×1 Array{Float64, 3}:
[:, :, 1] =
1.1 1.3 1.5
1.2 1.4 1.6
2.1 2.3 2.5
2.2 2.4 2.6
3.1 3.3 3.5
3.2 3.4 3.6
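In the 2-d case, reshape and permutedims can reproduce the index bookkeeping. The plain-Julia sketch below was written to match the 4-array example above. It is an illustration, not NNlib's implementation. It assumes the channel index splits as (r, r, C) with the first factor varying fastest. `naive_pixel_shuffle2d` is a hypothetical helper.

```julia
# Hypothetical 2-d sketch of pixel shuffling: (W, H, r^2*C, N) -> (r*W, r*H, C, N).
function naive_pixel_shuffle2d(x::AbstractArray{T,4}, r::Integer) where {T}
    W, H, Crr, N = size(x)
    C = Crr ÷ r^2
    y = reshape(x, W, H, r, r, C, N)         # channel = a + (b - 1) * r + (c - 1) * r^2
    z = permutedims(y, (3, 1, 4, 2, 5, 6))   # z[a, i, b, j, c, n] == y[i, j, a, b, c, n]
    return reshape(z, r * W, r * H, C, N)    # row a + (i - 1) * r, column b + (j - 1) * r
end

x = [10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
out = naive_pixel_shuffle2d(x, 2)
```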
Batched Operations
Flux's Flux.Bilinear layer uses NNlib.batched_mul internally.
NNlib.batched_mul (Function)
batched_mul(A, B) -> C
A ⊠ B  # \boxtimes
Batched matrix multiplication. Result has C[:,:,k...] == A[:,:,k...] * B[:,:,k...] where k... represents any indices in the last dimensions.
If ndims(A) == ndims(B) == 3 and size(B,3) == 1 then instead C[:,:,k] == A[:,:,k] * B[:,:,1], and similarly for A.
To transpose each matrix, apply batched_transpose to the array, or batched_adjoint for conjugate-transpose:
julia> A, B = randn(2,5,17), randn(5,9,17);
julia> A ⊠ B |> size
(2, 9, 17)
julia> batched_adjoint(A) |> size
(5, 2, 17)
julia> batched_mul(A, batched_adjoint(randn(9,5,17))) |> size
(2, 9, 17)
julia> A ⊠ randn(5,9,1) |> size
(2, 9, 17)
julia> batched_transpose(A) == PermutedDimsArray(A, (2,1,3))
true
The equivalent PermutedDimsArray may be used in place of batched_transpose. Other permutations are also handled by BLAS, provided that the batch index k is not the first dimension of the underlying array. Thus PermutedDimsArray(::Array, (1,3,2)) and PermutedDimsArray(::Array, (3,1,2)) are fine.
However, A = PermutedDimsArray(::Array, (3,2,1)) is not acceptable to BLAS, since the batch dimension is the contiguous one: stride(A,3) == 1. This will be copied, as doing so is faster than batched_mul_generic!.
Both this copy and batched_mul_generic! produce @debug messages, and setting for instance ENV["JULIA_DEBUG"] = NNlib will display them.
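The semantics above, including the reuse of a batch dimension of size 1, can be written out as a plain loop. `naive_batched_mul` is a hypothetical reference sketch. NNlib itself goes through batched_gemm! where possible.

```julia
# Hypothetical loop-based reference for 3-d A and B (not NNlib's implementation).
function naive_batched_mul(A::AbstractArray{<:Number,3}, B::AbstractArray{<:Number,3})
    nb = max(size(A, 3), size(B, 3))
    C = zeros(promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2), nb)
    for k in 1:nb
        # a batch dimension of size 1 is reused for every k
        Ak = A[:, :, size(A, 3) == 1 ? 1 : k]
        Bk = B[:, :, size(B, 3) == 1 ? 1 : k]
        C[:, :, k] = Ak * Bk
    end
    return C
end

A, B, B1 = randn(2, 5, 17), randn(5, 9, 17), randn(5, 9, 1)
C = naive_batched_mul(A, B)
C1 = naive_batched_mul(A, B1)   # B1[:, :, 1] is used for every batch
```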
batched_mul(A::Array{T,3}, B::Matrix)
batched_mul(A::Matrix, B::Array{T,3})
A ⊠ B
This is always matrix-matrix multiplication, but either A or B may lack a batch index.
When B is a matrix, the result has C[:,:,k] == A[:,:,k] * B[:,:] for all k.
When A is a matrix, then C[:,:,k] == A[:,:] * B[:,:,k]. This can also be done by reshaping and calling *, for instance A ⊡ B using TensorCore.jl, but is implemented here using batched_gemm instead of gemm.
julia> randn(16,8,32) ⊠ randn(8,4) |> size
(16, 4, 32)
julia> randn(16,8,32) ⊠ randn(8,4,1) |> size # equivalent
(16, 4, 32)
julia> randn(16,8) ⊠ randn(8,4,32) |> size
(16, 4, 32)
See also batched_vec to regard B as a batch of vectors, A[:,:,k] * B[:,k].
NNlib.batched_mul! (Function)
batched_mul!(C, A, B) -> C
batched_mul!(C, A, B, α=1, β=0)
In-place batched matrix multiplication, equivalent to mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β) for all k. If size(B,3) == 1 then every batch uses B[:,:,1] instead.
This will call batched_gemm! whenever possible. For real arrays this means that, for X ∈ [A,B,C], either stride(X,1)==1 or stride(X,2)==1; the latter may be caused by batched_transpose or by, for instance, PermutedDimsArray(::Array, (3,1,2)). Unlike batched_mul this will never make a copy.
For complex arrays, the wrapper made by batched_adjoint must be outermost to be seen. In this case the strides accepted by BLAS are more restricted: if stride(C,1)==1 then only stride(AorB::BatchedAdjoint,2) == 1 is accepted.
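The in-place form computes C[:,:,k] = α * A[:,:,k] * B[:,:,k] + β * C[:,:,k] for each k. A plain-Julia loop reference using LinearAlgebra's 5-argument mul! on views (`naive_batched_mul!` is hypothetical and never calls batched_gemm!):

```julia
using LinearAlgebra

# Hypothetical loop-based reference for the in-place form (not NNlib's implementation).
function naive_batched_mul!(C, A, B, α = 1, β = 0)
    for k in axes(C, 3)
        Bk = view(B, :, :, size(B, 3) == 1 ? 1 : k)          # reuse B[:, :, 1] if size(B, 3) == 1
        mul!(view(C, :, :, k), view(A, :, :, k), Bk, α, β)   # C_k = α * A_k * B_k + β * C_k
    end
    return C
end

A, B = rand(2, 5, 4), rand(5, 3, 4)
C = ones(2, 3, 4)
C0 = copy(C)
naive_batched_mul!(C, A, B, 2.0, 0.5)
```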
NNlib.batched_adjoint (Function)
batched_transpose(A::AbstractArray{T,3})
batched_adjoint(A)
Equivalent to applying transpose or adjoint to each matrix A[:,:,k].
These exist to control how batched_mul behaves, as it operates on such matrix slices of an array with ndims(A)==3.
PermutedDimsArray(A, (2,1,3)) is equivalent to batched_transpose(A), and is also understood by batched_mul (and more widely supported elsewhere).
BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}
BatchedAdjoint{T, S}
Lazy wrappers analogous to Transpose and Adjoint, returned by batched_transpose etc.
NNlib.batched_vec (Function)
batched_vec(A::Array{T,3}, B::Matrix)
batched_vec(A::Array{T,3}, b::Vector)
Batched matrix-vector multiplication: the result has C[:,k] == A[:,:,k] * B[:,k] for all k, or else C[:,k] == A[:,:,k] * b for b::Vector.
With the same argument types, batched_mul(A, B) would regard B as a fixed matrix, not a batch of vectors. Both reshape and then call batched_mul(::Array{T,3}, ::Array{T,3}).
julia> A, B, b = randn(16,8,32), randn(8,32), randn(8);
julia> batched_vec(A,B) |> size
(16, 32)
julia> batched_vec(A,b) |> size
(16, 32)
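Each result column is one matrix-vector product, so a loop reference is short. `naive_batched_vec` is a hypothetical plain-Julia helper, not NNlib's implementation.

```julia
# Hypothetical loop-based reference for batched_vec.
naive_batched_vec(A::AbstractArray{<:Number,3}, B::AbstractMatrix) =
    reduce(hcat, [A[:, :, k] * B[:, k] for k in axes(A, 3)])   # column k uses B[:, k]
naive_batched_vec(A::AbstractArray{<:Number,3}, b::AbstractVector) =
    reduce(hcat, [A[:, :, k] * b for k in axes(A, 3)])         # the same b for every k

A, B, b = randn(16, 8, 32), randn(8, 32), randn(8)
```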
Gather and Scatter
Flux's Embedding layer uses NNlib.gather as its backend.
NNlib.gather (Function)
NNlib.gather(src, idx) -> dst
Reverse operation of scatter. Gathers data from source src and writes it in a destination dst according to the index array idx. For each k in CartesianIndices(idx), assign values to dst according to
dst[:, ... , k] .= src[:, ... , idx[k]...]
Notice that if idx is a vector containing integers and src is a matrix, the previous expression simplifies to
dst[:, k] .= src[:, idx[k]]
and k will run over 1:length(idx).
The elements of idx can be integers or integer tuples and may be repeated. A single src column can end up being copied into zero, one, or multiple dst columns.
See gather! for an in-place version.
Examples
julia> NNlib.gather([1,20,300,4000], [2,4,2])
3-element Vector{Int64}:
20
4000
20
julia> NNlib.gather([1 2 3; 4 5 6], [1,3,1,3,1])
2×5 Matrix{Int64}:
1 3 1 3 1
4 6 4 6 4
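For a matrix src and a vector of integer indices, the rule above reduces to a column-copy loop. In this case the result equals src[:, idx]. The plain-Julia sketch below uses a hypothetical helper, `naive_gather`.

```julia
# Hypothetical loop-based sketch of gather for a matrix src and integer indices.
function naive_gather(src::AbstractMatrix, idx::AbstractVector{<:Integer})
    dst = similar(src, size(src, 1), length(idx))
    for k in eachindex(idx)
        dst[:, k] .= src[:, idx[k]]   # column k of dst is column idx[k] of src
    end
    return dst
end

src = [1 2 3; 4 5 6]
```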
gather(src, IJK...)
Convert the tuple of integer vectors IJK to a tuple of CartesianIndex and call gather on it: gather(src, CartesianIndex.(IJK...)).
Examples
julia> src = reshape([1:15;], 3, 5)
3×5 Matrix{Int64}:
1 4 7 10 13
2 5 8 11 14
3 6 9 12 15
julia> NNlib.gather(src, [1, 2], [2, 4])
2-element Vector{Int64}:
4
11
NNlib.gather! (Function)
NNlib.gather!(dst, src, idx)
Reverse operation of scatter!. Gathers data from source src and writes it in destination dst according to the index array idx. For each k in CartesianIndices(idx), assign values to dst according to
dst[:, ... , k] .= src[:, ... , idx[k]...]
Notice that if idx is a vector containing integers, and both dst and src are matrices, the previous expression simplifies to
dst[:, k] .= src[:, idx[k]]
and k will run over 1:length(idx).
The elements of idx can be integers or integer tuples and may be repeated. A single src column can end up being copied into zero, one, or multiple dst columns.
See gather for an allocating version.
NNlib.scatter (Function)
NNlib.scatter(op, src, idx; [init, dstsize])
Scatter operation allocating a destination array dst and calling scatter!(op, dst, src, idx) on it.
If keyword init is provided, it is used to initialize the content of dst. Otherwise, the init value is inferred from the reduction operator op for some common operators (e.g. init = 0 for op = +).
If dstsize is provided, it will be used to define the size of the destination array; otherwise it will be inferred from src and idx.
See scatter! for full details on how idx works.
Examples
julia> NNlib.scatter(+, [10,100,1000], [3,1,2])
3-element Vector{Int64}:
100
1000
10
julia> NNlib.scatter(+, [1 2 3 4; 5 6 7 8], [2,1,1,5])
2×5 Matrix{Int64}:
5 1 0 0 4
13 5 0 0 8
julia> NNlib.scatter(*, [10,200,3000], [1,4,2]; init = 10, dstsize = 6)
6-element Vector{Int64}:
100
30000
10
2000
10
10
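For op = + with a matrix src and integer indices, scatter reduces to an accumulation loop that starts from zeros. The plain-Julia sketch below reproduces the second example. `naive_scatter_add` is a hypothetical helper. Its default dstsize = maximum(idx) is an assumption for this case.

```julia
# Hypothetical loop-based sketch of scatter with op = + on matrix columns.
function naive_scatter_add(src::AbstractMatrix, idx::AbstractVector{<:Integer};
                           dstsize = maximum(idx))
    dst = zeros(eltype(src), size(src, 1), dstsize)   # init = 0, the neutral element of +
    for k in eachindex(idx)
        dst[:, idx[k]] .+= src[:, k]                  # repeated indices accumulate
    end
    return dst
end

y = naive_scatter_add([1 2 3 4; 5 6 7 8], [2, 1, 1, 5])
```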
NNlib.scatter! (Function)
NNlib.scatter!(op, dst, src, idx)
Scatter operation, which writes data in src into dst at locations idx. A binary reduction operator op is applied during the scatter. For each index k in idx, accumulates values in dst according to
dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])
Arguments
op: Operations to be applied on dst and src, e.g. +, -, *, /, max, min and mean.
dst: The destination for src to aggregate to. This argument will be mutated.
src: The source data for aggregating.
idx: The mapping for aggregation from source (index) to destination (value). The idx array can contain either integers or tuples.
Examples
julia> NNlib.scatter!(+, ones(3), [10,100], [1,3])
3-element Vector{Float64}:
11.0
1.0
101.0
julia> NNlib.scatter!(*, fill(0.5, 2, 4), [1 10; 100 1000], [3,2])
2×4 Matrix{Float64}:
0.5 5.0 0.5 0.5
0.5 500.0 50.0 0.5
Sampling
NNlib.grid_sample (Function)
grid_sample(input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros)
Given input, compute output by sampling input values at pixel locations from grid. Uses bilinear interpolation to calculate output values.
This implementation assumes the extrema (-1 and 1) are considered as referring to the center points of the input's corner pixels (i.e. align corners is true).
Arguments
input: Input array in (W_in, H_in, C, N) shape.
grid: Input grid in (2, W_out, H_out, N) shape, where for each (W_out, H_out, N) the grid contains (x, y) coordinates that specify sampling locations normalized by the input shape.
Therefore, x and y should have values in the [-1, 1] range. For example, (x = -1, y = -1) is the left-top pixel of input, and (x = 1, y = 1) is the right-bottom pixel of input.
Out-of-bound values are handled according to the padding_mode.
padding_mode: Out-of-bound padding. :zeros to use 0 for out-of-bound grid locations. :border to use border values for out-of-bound grid locations. Default is :zeros.
Returns
(W_out, H_out, C, N) sampled grid from input.
Examples
In the example below, grid contains two out-of-bound sampling locations, which are handled differently depending on the padding_mode.
julia> x = reshape(collect(1.0:4.0), (2, 2, 1, 1))
2×2×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
1.0 3.0
2.0 4.0
julia> grid = Array{Float64}(undef, 2, 3, 2, 1);
julia> grid[:, 1, 1, 1] .= (-3, -1);
julia> grid[:, 2, 1, 1] .= (0, -1);
julia> grid[:, 3, 1, 1] .= (1, -1);
julia> grid[:, 1, 2, 1] .= (-1, 1);
julia> grid[:, 2, 2, 1] .= (0, 1);
julia> grid[:, 3, 2, 1] .= (3, 1);
julia> grid_sample(x, grid; padding_mode=:zeros)
3×2×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
0.0 3.0
1.5 3.5
2.0 0.0
julia> grid_sample(x, grid; padding_mode=:border)
3×2×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
1.0 3.0
1.5 3.5
2.0 4.0
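The next sketch makes the coordinate convention concrete for one channel and one sampling point. Under the align-corners convention, a normalized coordinate c ∈ [-1, 1] maps to the pixel coordinate (c + 1) / 2 * (n - 1) + 1, and the four neighbouring pixels are then blended bilinearly. The sketch reproduces values from the example above. The helpers `unnormalize` and `naive_grid_sample_point` are hypothetical plain-Julia code, not NNlib's implementation.

```julia
# Hypothetical single-channel, single-point sketch of bilinear grid sampling (align corners).
unnormalize(c, n) = (c + 1) / 2 * (n - 1) + 1   # -1 ↦ 1 (first pixel), 1 ↦ n (last pixel)

function naive_grid_sample_point(img::AbstractMatrix, x::Real, y::Real; padding_mode = :zeros)
    W, H = size(img)
    ix, iy = unnormalize(x, W), unnormalize(y, H)
    if padding_mode == :border
        ix, iy = clamp(ix, 1, W), clamp(iy, 1, H)   # snap out-of-bound points onto the border
    end
    x0, y0 = floor(Int, ix), floor(Int, iy)
    wx, wy = ix - x0, iy - y0
    # neighbours outside the image contribute 0 (zeros padding)
    px(i, j) = (1 <= i <= W && 1 <= j <= H) ? img[i, j] : zero(eltype(img))
    return (1 - wx) * (1 - wy) * px(x0, y0) + wx * (1 - wy) * px(x0 + 1, y0) +
           (1 - wx) * wy * px(x0, y0 + 1) + wx * wy * px(x0 + 1, y0 + 1)
end

img = [1.0 3.0; 2.0 4.0]   # the same values as x[:, :, 1, 1] in the example above
```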
NNlib.∇grid_sample (Function)
∇grid_sample(Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros) where T
Arguments
Δ: Input gradient in (W_out, H_out, C, N) shape (same as output of the primal computation).
input: Input from primal computation in (W_in, H_in, C, N) shape.
grid: Grid from primal computation in (2, W_out, H_out, N) shape.
padding_mode: Out-of-bound padding. :zeros to use 0 for out-of-bound grid locations. :border to use border values for out-of-bound grid locations. Should be the same as in primal computation. Default is :zeros.
Returns
dinput (same shape as input) and dgrid (same shape as grid) gradients.
Losses
NNlib.ctc_loss (Function)
ctc_loss(ŷ, y)
Computes the connectionist temporal classification loss between ŷ and y. ŷ must be a classes-by-time matrix, i.e., each row represents a class and each column represents a time step. Additionally, the logsoftmax function will be applied to ŷ, so ŷ must be the raw activation values from the neural network and not, for example, the activations after being passed through a softmax activation function. y must be a 1D array of the labels associated with ŷ. The blank label is assumed to be the last label category in ŷ, so it is equivalent to size(ŷ, 1). Used for sequence-to-sequence classification problems such as speech recognition and handwriting recognition where the exact time-alignment of the output (e.g., letters) is not needed to solve the problem. See Graves et al. (2006) or Graves (2012) for mathematical details.
Miscellaneous
NNlib.logsumexp (Function)
logsumexp(x; dims = :)
Computes log.(sum(exp.(x); dims)) in a numerically stable way. Without the dims keyword this returns a scalar.
See also logsoftmax.
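A common way to get this stability, assumed in the plain-Julia sketch below, uses the identity log(∑ exp(x)) = m + log(∑ exp(x - m)) with m the maximum along dims. Subtracting the result from x gives logsoftmax along the same dims. `naive_logsumexp` is a hypothetical helper, not NNlib's code.

```julia
# Hypothetical sketch of a numerically stable logsumexp (not NNlib's implementation).
function naive_logsumexp(x; dims = :)
    m = maximum(x; dims = dims)                        # a scalar when dims = :
    return m .+ log.(sum(exp.(x .- m); dims = dims))   # log(∑ exp(x)) = m + log(∑ exp(x - m))
end

x = [1000.0 1.0; 1000.0 2.0]
```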
NNlib.glu (Function)
glu(x, dim = 1)
The gated linear unit from the "Language Modeling with Gated Convolutional Networks" paper.
Calculates a .* sigmoid(b), where x is split in half along given dimension dim to form a and b.
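The sketch below computes the same thing in plain Julia, assuming the first half of x along dim is a and the second half is b. The helpers `logistic` and `naive_glu` are hypothetical. A local logistic function stands in for NNlib's sigmoid so the block has no dependencies.

```julia
# Hypothetical sketch of the gated linear unit (not NNlib's implementation).
logistic(z) = 1 / (1 + exp(-z))

function naive_glu(x::AbstractArray, dim::Integer = 1)
    n = size(x, dim) ÷ 2
    a = selectdim(x, dim, 1:n)        # assumed: first half is a
    b = selectdim(x, dim, n+1:2n)     # assumed: second half is b (the gate)
    return a .* logistic.(b)
end
```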