Neural Network primitives from NNlib.jl

Flux re-exports all of the functions exported by the NNlib package. This includes activation functions, described on the next page. Many of the functions on this page exist primarily as the internal implementation of Flux layers, but can also be used independently.

Softmax

Flux's logitcrossentropy uses NNlib.logsoftmax internally.

NNlib.softmax — Function
softmax(x; dims = 1)

Softmax turns input array x into probability distributions that sum to 1 along the dimensions specified by dims. It is semantically equivalent to the following:

softmax(x; dims = 1) = exp.(x) ./ sum(exp.(x), dims = dims)

with additional manipulations enhancing numerical stability.

For a matrix input x it will by default (dims = 1) treat it as a batch of vectors, with each column independent. Keyword dims = 2 will instead treat rows independently, and so on.
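
The sketch below implements that definition in plain Julia. It assumes the usual stabilizing manipulation: subtract each slice's maximum before exponentiating. The name `my_softmax` is hypothetical, and this is not NNlib's implementation. Shifting every entry of a slice by the same constant leaves the ratios unchanged, and it stops `exp` from overflowing.

```julia
# Hypothetical reference softmax (not NNlib's code), stabilized by a max shift.
function my_softmax(x::AbstractArray{<:AbstractFloat}; dims = 1)
    # Shift each slice along `dims` so its largest entry becomes 0.
    e = exp.(x .- maximum(x; dims = dims))
    # Normalize so each slice along `dims` sums to 1.
    return e ./ sum(e; dims = dims)
end

x = [1.0 2.0 3.0; 2.0 2.0 2.0]
my_softmax(x)            # each column sums to 1 (dims = 1)
my_softmax(x; dims = 2)  # each row sums to 1

large = [1000.0, 1000.0]
exp.(large) ./ sum(exp.(large))  # [NaN, NaN]: exp overflows to Inf
my_softmax(large)                # [0.5, 0.5]
```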

See also logsoftmax.

Examples

julia> softmax([1, 2, 3])
3-element Vector{Float64}:
 0.09003057317038046
 0.24472847105479764
 0.6652409557748218

julia> softmax([1 2 3; 2 2 2])  # dims=1
2×3 Matrix{Float64}:
 0.268941  0.5  0.731059
 0.731059  0.5  0.268941

julia> softmax([1 2 3; 2 2 2]; dims=2)
2×3 Matrix{Float64}:
 0.0900306  0.244728  0.665241
 0.333333   0.333333  0.333333

Note that, when used with Flux.jl, softmax must not be passed to layers like Dense which accept an activation function. The activation is broadcast over the result, so it is applied to individual numbers, whereas softmax always needs to see the whole column.

julia> using Flux

julia> x = randn(Float32, 4, 4, 3, 13);

julia> model = Chain(Conv((4, 4), 3 => 8, tanh), Flux.flatten, Dense(8 => 7), softmax);

julia> model(x) |> size
(7, 13)

julia> Dense(4 => 7, softmax)(x)
ERROR: `softmax(x)` called with a number, but it expects an array.

NNlib.logsoftmax — Function
logsoftmax(x; dims = 1)

Computes the log of softmax in a more numerically stable way than directly taking log.(softmax(x)). Commonly used in computing cross entropy loss.

It is semantically equivalent to the following:

logsoftmax(x; dims = 1) = x .- log.(sum(exp.(x), dims = dims))
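
The sketch below shows why a dedicated function is more stable. It assumes the same max-shift trick as for softmax, and `my_logsoftmax` is a hypothetical name, not NNlib's code. If you take `log.(softmax(x))` directly, any entry whose probability underflows to zero becomes `-Inf`.

```julia
# Hypothetical reference logsoftmax (not NNlib's code).
function my_logsoftmax(x::AbstractArray{<:AbstractFloat}; dims = 1)
    m = maximum(x; dims = dims)
    # log(sum(exp.(x))) == m + log(sum(exp.(x .- m))); the right side cannot overflow.
    return x .- (m .+ log.(sum(exp.(x .- m); dims = dims)))
end

x = [0.0, -800.0]
log.(exp.(x) ./ sum(exp.(x)))  # [0.0, -Inf]: exp(-800.0) underflows to 0.0
my_logsoftmax(x)               # [0.0, -800.0], still finite
```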

See also softmax.

Pooling

Flux's AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, and MeanPool use NNlib.PoolDims, NNlib.maxpool, and NNlib.meanpool as their backend.

NNlib.PoolDims — Type
PoolDims(x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int};
        stride=k, padding=0, dilation=1)  where {M, L}

Dimensions for a "pooling" operation that can have an arbitrary input size, kernel size, stride, dilation, and channel count. Used to dispatch onto efficient implementations at compile-time.

NNlib.maxpool — Function
maxpool(x, k::NTuple; pad=0, stride=k)

Perform max pool operation with window size k on input tensor x.

NNlib.meanpool — Function
meanpool(x, k::NTuple; pad=0, stride=k)

Perform mean pool operation with window size k on input tensor x.
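
The following sketch makes the window semantics concrete. It performs 2-d max and mean pooling on a (W, H, C, N) array, with the stride equal to the window and no padding or dilation, which mirrors the `stride=k, pad=0` defaults. The helper names are hypothetical, and NNlib implements pooling differently.

```julia
# Hypothetical sketch, not NNlib's implementation: pool a (W, H, C, N) array with
# window k = (kw, kh), stride equal to k, and no padding or dilation.
function pool2d(op, x::AbstractArray{<:Real,4}, k::NTuple{2,Int})
    W, H, C, N = size(x)
    kw, kh = k
    # Output length per spatial dim: (in - k) ÷ stride + 1, here with stride == k.
    Wout, Hout = (W - kw) ÷ kw + 1, (H - kh) ÷ kh + 1
    return [op(view(x, (i-1)*kw+1:i*kw, (j-1)*kh+1:j*kh, c, n))
            for i in 1:Wout, j in 1:Hout, c in 1:C, n in 1:N]
end

maxpool2d(x, k) = pool2d(maximum, x, k)
meanpool2d(x, k) = pool2d(w -> sum(w) / length(w), x, k)

x = reshape(Float32.(1:16), 4, 4, 1, 1)
maxpool2d(x, (2, 2))[:, :, 1, 1]   # [6.0 14.0; 8.0 16.0]
meanpool2d(x, (2, 2))[:, :, 1, 1]  # [3.5 11.5; 5.5 13.5]
```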

Padding

NNlib.pad_reflect — Function
pad_reflect(x, pad::Tuple; [dims])
pad_reflect(x, pad::Int; [dims])

Pad the array x reflecting its values across the border.

pad can be a tuple of integers (l1, r1, ..., ln, rn) of some length 2n that specifies the left and right padding size for each of the dimensions in dims. If dims is not given, it defaults to the first n dimensions.

For integer pad input instead, it is applied on both sides on every dimension in dims. In this case, dims defaults to the first ndims(x)-2 dimensions (i.e. it excludes the channel and batch dimensions).

See also pad_repeat and pad_constant.
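
Here is a minimal sketch of the reflection rule for the tuple form. It assumes the border element itself is not repeated, which is consistent with the example below. `reflect_index` and `my_pad_reflect` are hypothetical helpers, not NNlib functions.

```julia
# Hypothetical helper: map a padded index back into 1:n by mirroring across the
# first and last elements (which are not themselves duplicated).
function reflect_index(i::Int, n::Int)
    n >= 2 || throw(ArgumentError("reflect padding needs at least 2 elements"))
    while !(1 <= i <= n)
        i = i < 1 ? 2 - i : 2n - i
    end
    return i
end

# Hypothetical vector version: l elements on the left, r on the right.
my_pad_reflect(x::AbstractVector, l::Int, r::Int) =
    [x[reflect_index(i, length(x))] for i in 1-l:length(x)+r]

# Hypothetical matrix version with pad = (l1, r1, l2, r2).
function my_pad_reflect(x::AbstractMatrix, pad::NTuple{4,Int})
    l1, r1, l2, r2 = pad
    m, n = size(x)
    return [x[reflect_index(i, m), reflect_index(j, n)] for i in 1-l1:m+r1, j in 1-l2:n+r2]
end

my_pad_reflect([1, 2, 3], 1, 2)                   # [2, 1, 2, 3, 2, 1]
my_pad_reflect(reshape(1:9, 3, 3), (1, 2, 1, 2))  # the 6×6 matrix in the example below
```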

julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
 1  4  7
 2  5  8
 3  6  9

julia> pad_reflect(r, (1,2,1,2))
6×6 Matrix{Int64}:
 5  2  5  8  5  2
 4  1  4  7  4  1
 5  2  5  8  5  2
 6  3  6  9  6  3
 5  2  5  8  5  2
 4  1  4  7  4  1

NNlib.pad_constant — Function
pad_constant(x, pad::Tuple, val = 0; [dims = :])
pad_constant(x, pad::Int, val = 0; [dims = :])

Pad the array x with the constant value val.

pad can be a tuple of integers. If it has length 2 * length(dims), it specifies the left and right padding size for each of the dimensions in dims as (l1, r1, ..., ln, rn). If it has length length(dims) instead, symmetric padding is applied. If dims is not given, it defaults to all dimensions.

For integer pad input, it is applied on both sides on every dimension in dims.

See also pad_zeros, pad_reflect and pad_repeat.
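
The sketch below shows how the tuple forms are read for a matrix. A tuple of length 2n gives (left, right) amounts for each dimension, and a tuple of length n gives symmetric padding. `expand_pad` and `my_pad_constant` are hypothetical helpers, not NNlib functions.

```julia
# Hypothetical helper: turn a pad tuple into per-dimension (left, right) pairs.
function expand_pad(pad::NTuple{M,Int}, ndims_padded::Int) where M
    if M == 2 * ndims_padded
        return [(pad[2d-1], pad[2d]) for d in 1:ndims_padded]
    elseif M == ndims_padded
        return [(p, p) for p in pad]  # symmetric padding
    else
        throw(ArgumentError("padding must have length $ndims_padded or $(2 * ndims_padded), got $M"))
    end
end

# Hypothetical constant padding of a matrix using the expanded pairs.
function my_pad_constant(x::AbstractMatrix, pad::Tuple, val = 0)
    (l1, r1), (l2, r2) = expand_pad(pad, 2)
    m, n = size(x)
    y = fill(convert(eltype(x), val), l1 + m + r1, l2 + n + r2)
    y[l1+1:l1+m, l2+1:l2+n] .= x  # copy the original block into the middle
    return y
end

r = reshape(1:4, 2, 2)
size(my_pad_constant(r, (1, 2, 3, 4), 8))  # (5, 9)
my_pad_constant(r, (1, 1), 8)              # symmetric, like pad_constant(r, 1, 8) below
```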

julia> r = reshape(1:4, 2, 2)
2×2 reshape(::UnitRange{Int64}, 2, 2) with eltype Int64:
 1  3
 2  4

julia> pad_constant(r, (1, 2, 3, 4), 8)
5×9 Matrix{Int64}:
 8  8  8  8  8  8  8  8  8
 8  8  8  1  3  8  8  8  8
 8  8  8  2  4  8  8  8  8
 8  8  8  8  8  8  8  8  8
 8  8  8  8  8  8  8  8  8

julia> pad_constant(r, 1, 8)
4×4 Matrix{Int64}:
 8  8  8  8
 8  1  3  8
 8  2  4  8
 8  8  8  8

julia> r = reshape(1:27, 3, 3, 3)
3×3×3 reshape(::UnitRange{Int64}, 3, 3, 3) with eltype Int64:
[:, :, 1] =
 1  4  7
 2  5  8
 3  6  9

[:, :, 2] =
 10  13  16
 11  14  17
 12  15  18

[:, :, 3] =
 19  22  25
 20  23  26
 21  24  27

julia> pad_constant(r, (2,1), dims = 1) # asymmetric padding
6×3×3 Array{Int64, 3}:
[:, :, 1] =
 0  0  0
 0  0  0
 1  4  7
 2  5  8
 3  6  9
 0  0  0

[:, :, 2] =
  0   0   0
  0   0   0
 10  13  16
 11  14  17
 12  15  18
  0   0   0

[:, :, 3] =
  0   0   0
  0   0   0
 19  22  25
 20  23  26
 21  24  27
  0   0   0

julia> pad_constant(r, (2,1, 3), dims = (1,2)) # padding must always be either the same length as dims, or double it
ERROR: ArgumentError: Could not parse padding (2, 1, 3) and dims (1, 2)
Stacktrace:
[...]

NNlib.pad_repeat — Function
pad_repeat(x, pad::Tuple; [dims])
pad_repeat(x, pad::Int; [dims])

Pad the array x repeating the values on the border.

pad can be a tuple of integers (l1, r1, ..., ln, rn) of some length 2n that specifies the left and right padding size for each of the dimensions in dims. If dims is not given, it defaults to the first n dimensions.

For integer pad input instead, it is applied on both sides on every dimension in dims. In this case, dims defaults to the first ndims(x)-2 dimensions (i.e. it excludes the channel and batch dimensions).

See also pad_reflect and pad_constant.
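
pad_reflect mirrors values across the border. Repeat padding instead clamps every out-of-range index to the nearest valid one. The minimal sketch below uses a hypothetical `my_pad_repeat` helper (not an NNlib function) for a matrix.

```julia
# Hypothetical sketch of repeat (edge) padding with pad = (l1, r1, l2, r2).
function my_pad_repeat(x::AbstractMatrix, pad::NTuple{4,Int})
    l1, r1, l2, r2 = pad
    m, n = size(x)
    # Indices outside 1:m (or 1:n) are clamped, so edge values are repeated.
    return [x[clamp(i, 1, m), clamp(j, 1, n)] for i in 1-l1:m+r1, j in 1-l2:n+r2]
end

my_pad_repeat(reshape(1:9, 3, 3), (1, 2, 3, 4))  # the 6×10 matrix in the example below
```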

julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
 1  4  7
 2  5  8
 3  6  9

julia> pad_repeat(r, (1,2,3,4))
6×10 Matrix{Int64}:
 1  1  1  1  4  7  7  7  7  7
 1  1  1  1  4  7  7  7  7  7
 2  2  2  2  5  8  8  8  8  8
 3  3  3  3  6  9  9  9  9  9
 3  3  3  3  6  9  9  9  9  9
 3  3  3  3  6  9  9  9  9  9

NNlib.pad_zeros — Function
pad_zeros(x, pad::Tuple; [dims])
pad_zeros(x, pad::Int; [dims])

Pad the array x with zeros. Equivalent to pad_constant with the constant equal to 0.

Convolution

Flux's Conv and CrossCor layers use NNlib.DenseConvDims and NNlib.conv internally.

NNlib.conv — Function
conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)

Apply convolution filter w to input x. x and w are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
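
Here is a 1-d, single-channel sketch with stride 1 and no padding, which shows what the flipped keyword changes. It assumes that flipped = false means true convolution (the kernel is reversed) and flipped = true means cross-correlation. That reading is an assumption for this sketch, and `conv1d` is a hypothetical helper, not part of NNlib.

```julia
# Hypothetical 1-d, single-channel sketch (not NNlib's code): stride 1, no padding.
function conv1d(x::AbstractVector, w::AbstractVector; flipped::Bool = false)
    k = length(w)
    # Assumed semantics: reverse the kernel unless flipped (cross-correlation) is requested.
    wk = flipped ? w : reverse(w)
    n_out = length(x) - k + 1  # output length for stride 1, pad 0, dilation 1
    return [sum(x[i+j-1] * wk[j] for j in 1:k) for i in 1:n_out]
end

x = [1.0, 2.0, 3.0, 4.0]
w = [1.0, 0.0, -1.0]
conv1d(x, w)                  # [2.0, 2.0], i.e. x[i+2] - x[i]
conv1d(x, w; flipped = true)  # [-2.0, -2.0], i.e. x[i] - x[i+2]
```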

NNlib.ConvDims — Type
ConvDims

Type system-level information about convolution dimensions. Critical for things like im2col!() to generate efficient code, and helpful to reduce the number of kwargs getting passed around.

NNlib.depthwiseconv — Function
depthwiseconv(x, w; stride=1, pad=0, dilation=1, flipped=false)

Depthwise convolution operation with filter w on input x. x and w are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.

NNlib.DepthwiseConvDims — Type
DepthwiseConvDims

Concrete subclass of ConvDims for a depthwise convolution. Differs primarily due to characterization by Cin, Cmult, rather than Cin, Cout. Useful to be separate from DenseConvDims primarily for channel calculation differences.
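
This sketch shows why depthwise convolution is described by Cin and Cmult: each input channel is filtered only by its own Cmult kernels, so the output has Cin * Cmult channels. The array layouts are assumptions made for this sketch only. Here x is (L, Cin), w is (K, Cmult, Cin), and output channel (c - 1) * Cmult + m comes from input channel c. NNlib's actual ordering may differ, and `depthwise1d` is hypothetical.

```julia
# Hypothetical 1-d depthwise sketch (cross-correlation form, stride 1, no padding).
function depthwise1d(x::AbstractMatrix, w::AbstractArray{<:Any,3})
    L, Cin = size(x)
    K, Cmult, Cin_w = size(w)
    Cin == Cin_w || throw(DimensionMismatch("channel counts differ"))
    Lout = L - K + 1
    y = zeros(promote_type(eltype(x), eltype(w)), Lout, Cin * Cmult)
    for c in 1:Cin, m in 1:Cmult, i in 1:Lout
        # Input channel c only ever meets its own Cmult filters w[:, :, c].
        y[i, (c - 1) * Cmult + m] = sum(x[i + j - 1, c] * w[j, m, c] for j in 1:K)
    end
    return y
end

x = rand(8, 3)           # length 8, Cin = 3
w = rand(3, 2, 3)        # K = 3, Cmult = 2, Cin = 3
size(depthwise1d(x, w))  # (6, 6): Cout = Cin * Cmult
```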

NNlib.DenseConvDims — Type
DenseConvDims

Concrete subclass of ConvDims for a normal, dense, conv2d/conv3d.

Upsampling

Flux's Upsample layer uses NNlib.upsample_nearest, NNlib.upsample_bilinear, and NNlib.upsample_trilinear as its backend. Additionally, Flux's PixelShuffle layer uses NNlib.pixel_shuffle as its backend.

NNlib.upsample_nearest — Function
upsample_nearest(x, scale::NTuple{S,Int})
upsample_nearest(x; size::NTuple{S,Int})

Upsamples the array x by integer multiples along the first S dimensions. Subsequent dimensions of x are not altered.

Either the scale factors or the final output size can be specified.
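
For integer scale factors, nearest-neighbour upsampling of the leading dimensions repeats each element. Base's `repeat` with the `inner` keyword reproduces the outputs shown in the example below. We have checked this equivalence only against those outputs, and `nearest_sketch` is a hypothetical helper, not an NNlib function.

```julia
x = [1 2 3; 4 5 6]
repeat(x; inner = (2, 3))  # 4×9, same values as upsample_nearest(x, (2, 3))
repeat(x; inner = (2, 1))  # 4×3, same values as upsample_nearest(x, (2,))

# Hypothetical helper: pad the scale tuple with 1s so trailing dims are untouched.
nearest_sketch(x, scale::Tuple) =
    repeat(x; inner = (scale..., ntuple(_ -> 1, ndims(x) - length(scale))...))
```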

See also upsample_bilinear, for two dimensions of an N=4 array.

Example

julia> upsample_nearest([1 2 3; 4 5 6], (2, 3))
4×9 Matrix{Int64}:
 1  1  1  2  2  2  3  3  3
 1  1  1  2  2  2  3  3  3
 4  4  4  5  5  5  6  6  6
 4  4  4  5  5  5  6  6  6

julia> ans == upsample_nearest([1 2 3; 4 5 6]; size=(4, 9))  # equivalent
true

julia> upsample_nearest([1 2 3; 4 5 6], (2,))
4×3 Matrix{Int64}:
 1  2  3
 1  2  3
 4  5  6
 4  5  6

julia> ans == upsample_nearest([1 2 3; 4 5 6], size=(4,))
true

NNlib.upsample_linear — Function
upsample_linear(x::AbstractArray{T,3}, scale::Real)
upsample_linear(x::AbstractArray{T,3}; size::Integer)

Upsamples the first dimension of the array x by the provided upsampling factor scale, using linear interpolation. As an alternative to using scale, the resulting array size can be directly specified with a keyword argument.

The size of the output is equal to (scale*S1, S2, S3), where S1, S2, S3 = size(x).

NNlib.∇upsample_linear — Function
∇upsample_linear(Δ::AbstractArray{T,3}; size::Integer) where T

Arguments

  • Δ: Incoming gradient array, backpropagated from downstream layers
  • size: Size of the image upsampled in the first place

Outputs

  • dx: Downsampled version of Δ
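
Linear upsampling is a linear map, so its gradient is the transpose of that map applied to Δ. That is what "downsampled version of Δ" means here. The sketch below assumes the corner-aligned sampling positions visible in the upsample_bilinear example. `linear_upsample_matrix` is hypothetical, and the dense-matrix formulation is for illustration only.

```julia
# Hypothetical helper (not an NNlib function): dense M with y = M * x for
# corner-aligned linear upsampling from length n_in to length n_out.
function linear_upsample_matrix(n_in::Int, n_out::Int)
    M = zeros(n_out, n_in)
    for i in 1:n_out
        # Fractional source position of output sample i; the two ends map to the ends.
        p = n_out == 1 ? 1.0 : 1 + (i - 1) * (n_in - 1) / (n_out - 1)
        lo = min(floor(Int, p), n_in)
        hi = min(lo + 1, n_in)
        t = p - lo
        M[i, lo] += 1 - t
        M[i, hi] += t
    end
    return M
end

x = [1.0, 4.0]
M = linear_upsample_matrix(2, 4)  # scale 2
M * x                             # ≈ [1.0, 2.0, 3.0, 4.0]

# Backward pass: each input collects the weights with which it was used.
Δ = ones(4)
M' * Δ                            # ≈ [2.0, 2.0]
```
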
NNlib.upsample_bilinear — Function
upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real})
upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer})

Upsamples the first 2 dimensions of the array x by the upsample factors stored in scale, using bilinear interpolation. As an alternative to using scale, the resulting image size can be directly specified with a keyword argument.

The size of the output is equal to (scale[1]*S1, scale[2]*S2, S3, S4), where S1, S2, S3, S4 = size(x).
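
Bilinear upsampling of a single W×H image can be written as separable 1-d linear interpolation, one pass along each axis. The sketch below reproduces the 4×9 slice in the example that follows, and also the first column of the (2.5, 3.5) case. `interp_matrix` and `bilinear_sketch` are hypothetical helpers, and corner alignment is assumed from those example values.

```julia
# Hypothetical helper: corner-aligned 1-d linear interpolation as a matrix.
function interp_matrix(n_in::Int, n_out::Int)
    M = zeros(n_out, n_in)
    for i in 1:n_out
        p = n_out == 1 ? 1.0 : 1 + (i - 1) * (n_in - 1) / (n_out - 1)
        lo = min(floor(Int, p), n_in)
        hi = min(lo + 1, n_in)
        M[i, lo] += 1 - (p - lo)
        M[i, hi] += p - lo
    end
    return M
end

# Hypothetical sketch: interpolate along the first dim, then along the second.
function bilinear_sketch(img::AbstractMatrix, out_size::NTuple{2,Int})
    Mw = interp_matrix(size(img, 1), out_size[1])
    Mh = interp_matrix(size(img, 2), out_size[2])
    return Mw * img * transpose(Mh)
end

img = Float64[1 2 3; 4 5 6]
bilinear_sketch(img, (4, 9))  # ≈ the [:, :, 1, 1] slice of upsample_bilinear(x, (2, 3))
```

With this formulation the backward pass corresponds to `transpose(Mw) * Δ * Mh`. This is offered only as intuition for ∇upsample_bilinear, not as its implementation.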

Examples

julia> x = reshape(Float32[1 2 3; 4 5 6], (2,3,1,1))
2×3×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
 1.0  2.0  3.0
 4.0  5.0  6.0

julia> upsample_bilinear(x, (2, 3))
4×9×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
 1.0  1.25  1.5  1.75  2.0  2.25  2.5  2.75  3.0
 2.0  2.25  2.5  2.75  3.0  3.25  3.5  3.75  4.0
 3.0  3.25  3.5  3.75  4.0  4.25  4.5  4.75  5.0
 4.0  4.25  4.5  4.75  5.0  5.25  5.5  5.75  6.0

julia> ans == upsample_bilinear(x; size=(4, 9))  # specify output size instead
true

julia> upsample_bilinear(x, (2.5, 3.5))  # non-integer scaling factors are allowed
5×10×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
 1.0   1.22222  1.44444  1.66667  1.88889  …  2.33333  2.55556  2.77778  3.0
 1.75  1.97222  2.19444  2.41667  2.63889     3.08333  3.30556  3.52778  3.75
 2.5   2.72222  2.94444  3.16667  3.38889     3.83333  4.05556  4.27778  4.5
 3.25  3.47222  3.69444  3.91667  4.13889     4.58333  4.80556  5.02778  5.25
 4.0   4.22222  4.44444  4.66667  4.88889     5.33333  5.55556  5.77778  6.0

NNlib.∇upsample_bilinear — Function
∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T

Arguments

  • Δ: Incoming gradient array, backpropagated from downstream layers
  • size: Lateral (W,H) size of the image upsampled in the first place

Outputs

  • dx: Downsampled version of Δ

NNlib.upsample_trilinear — Function
upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real})
upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer})

Upsamples the first 3 dimensions of the array x by the upsample factors stored in scale, using trilinear interpolation. As an alternative to using scale, the resulting image size can be directly specified with a keyword argument.

The size of the output is equal to (scale[1]*S1, scale[2]*S2, scale[3]*S3, S4, S5), where S1, S2, S3, S4, S5 = size(x).

Examples

x = rand(Float32, 3, 3, 3, 1, 1)        # any 5-d array of size (W, H, D, C, N)
upsample_trilinear(x, (2, 3, 4))        # output size (6, 9, 12, 1, 1)
upsample_trilinear(x; size=(4, 9, 11))  # specify output size instead
upsample_trilinear(x, (2.5, 3.5, pi))   # non-integer scaling factors are allowed

NNlib.∇upsample_trilinear — Function
∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}) where T

Arguments

  • Δ: Incoming gradient array, backpropagated from downstream layers
  • size: Lateral size & depth (W,H,D) of the image upsampled in the first place

Outputs

  • dx: Downsampled version of Δ

NNlib.pixel_shuffle — Function
pixel_shuffle(x, r::Integer)

Pixel shuffling operation, upscaling by a factor r.

For 4-arrays representing N images, the operation converts input size(x) == (W, H, r^2*C, N) to output of size (r*W, r*H, C, N). For D-dimensional data, it expects ndims(x) == D+2 with channel and batch dimensions, and divides the number of channels by r^D.
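
For the 2-d case, the rearrangement can be sketched with `reshape` and `permutedims`. The channel axis is split into (a, b, c), where a and b are the offsets inside each r×r output block, and those offsets are then interleaved with the W and H axes. This index mapping was checked against the example below. `pixel_shuffle_2d` is a hypothetical helper, not NNlib's implementation.

```julia
# Hypothetical 2-d sketch for a (W, H, r^2*C, N) array.
function pixel_shuffle_2d(x::AbstractArray{<:Any,4}, r::Integer)
    W, H, Crr, N = size(x)
    Crr % r^2 == 0 || throw(ArgumentError("channel count must be divisible by r^2"))
    C = Crr ÷ r^2
    y = reshape(x, W, H, r, r, C, N)        # dims (i, j, a, b, c, n)
    y = permutedims(y, (3, 1, 4, 2, 5, 6))  # dims (a, i, b, j, c, n)
    return reshape(y, r * W, r * H, C, N)   # row (i-1)*r + a, column (j-1)*r + b
end

x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
pixel_shuffle_2d(x, 2)[:, :, 1, 1]  # same 4×6 layout as pixel_shuffle(x, 2) below
```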

Used in super-resolution networks to upsample towards high resolution features. Reference: Shi et al., "Real-Time Single Image and Video Super-Resolution ...", CVPR 2016, https://arxiv.org/abs/1609.05158

Examples

julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
2×3×4×1 Array{Float64, 4}:
[:, :, 1, 1] =
 11.1  12.1  13.1
 21.1  22.1  23.1

[:, :, 2, 1] =
 11.2  12.2  13.2
 21.2  22.2  23.2

[:, :, 3, 1] =
 11.3  12.3  13.3
 21.3  22.3  23.3

[:, :, 4, 1] =
 11.4  12.4  13.4
 21.4  22.4  23.4

julia> pixel_shuffle(x, 2)  # 4 channels used up as 2x upscaling of image dimensions
4×6×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 11.1  11.3  12.1  12.3  13.1  13.3
 11.2  11.4  12.2  12.4  13.2  13.4
 21.1  21.3  22.1  22.3  23.1  23.3
 21.2  21.4  22.2  22.4  23.2  23.4

julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1]
3×6×1 Array{Float64, 3}:
[:, :, 1] =
 1.1  1.2  1.3  1.4  1.5  1.6
 2.1  2.2  2.3  2.4  2.5  2.6
 3.1  3.2  3.3  3.4  3.5  3.6

julia> pixel_shuffle(y, 2)  # 1D image, with 6 channels reduced to 3
6×3×1 Array{Float64, 3}:
[:, :, 1] =
 1.1  1.3  1.5
 1.2  1.4  1.6
 2.1  2.3  2.5
 2.2  2.4  2.6
 3.1  3.3  3.5
 3.2  3.4  3.6

Batched Operations

Flux's Bilinear layer uses NNlib.batched_mul internally.

NNlib.batched_mul — Function
batched_mul(A, B) -> C
A ⊠ B  # \boxtimes

Batched matrix multiplication. Result has C[:,:,k] == A[:,:,k] * B[:,:,k] for all k. If size(B,3) == 1 then instead C[:,:,k] == A[:,:,k] * B[:,:,1], and similarly for A.
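
The following loop-based reference spells out the rule just stated. It multiplies matching slices and reuses slice 1 of any argument whose batch size is 1. `batched_mul_ref` is a hypothetical name, and this is not how NNlib dispatches to BLAS.

```julia
using LinearAlgebra: mul!

# Hypothetical reference (not NNlib's code) for batched matrix multiplication.
function batched_mul_ref(A::AbstractArray{<:Number,3}, B::AbstractArray{<:Number,3})
    nA, nB = size(A, 3), size(B, 3)
    nbatch = max(nA, nB)
    (nA in (1, nbatch) && nB in (1, nbatch)) ||
        throw(DimensionMismatch("batch sizes $nA and $nB are incompatible"))
    C = zeros(promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2), nbatch)
    for k in 1:nbatch
        Ak = view(A, :, :, nA == 1 ? 1 : k)  # reuse slice 1 when A has one batch
        Bk = view(B, :, :, nB == 1 ? 1 : k)
        mul!(view(C, :, :, k), Ak, Bk)       # C[:,:,k] = Ak * Bk
    end
    return C
end

A, B = randn(2, 5, 17), randn(5, 9, 17)
size(batched_mul_ref(A, B))               # (2, 9, 17)
size(batched_mul_ref(A, randn(5, 9, 1)))  # (2, 9, 17)
```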

To transpose each matrix, apply batched_transpose to the array, or batched_adjoint for conjugate-transpose:

julia> A, B = randn(2,5,17), randn(5,9,17);

julia> A ⊠ B |> size
(2, 9, 17)

julia> batched_adjoint(A) |> size
(5, 2, 17)

julia> batched_mul(A, batched_adjoint(randn(9,5,17))) |> size
(2, 9, 17)

julia> A ⊠ randn(5,9,1) |> size
(2, 9, 17)

julia> batched_transpose(A) == PermutedDimsArray(A, (2,1,3))
true

The equivalent PermutedDimsArray may be used in place of batched_transpose. Other permutations are also handled by BLAS, provided that the batch index k is not the first dimension of the underlying array. Thus PermutedDimsArray(::Array, (1,3,2)) and PermutedDimsArray(::Array, (3,1,2)) are fine.

However, A = PermutedDimsArray(::Array, (3,2,1)) is not acceptable to BLAS, since the batch dimension is the contiguous one: stride(A,3) == 1. This will be copied, as doing so is faster than batched_mul_generic!.

Both this copy and batched_mul_generic! produce @debug messages, and setting for instance ENV["JULIA_DEBUG"] = NNlib will display them.
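
You can check the stride rule above with Base alone; no NNlib is needed. Each `PermutedDimsArray` reports the parent's strides in permuted order, so it is easy to see when the batch (third) dimension has become the contiguous one.

```julia
A = zeros(4, 5, 6)
strides(A)                                # (1, 4, 20)
strides(PermutedDimsArray(A, (2, 1, 3)))  # (4, 1, 20): like batched_transpose
strides(PermutedDimsArray(A, (3, 1, 2)))  # (20, 1, 4): batch index not contiguous
strides(PermutedDimsArray(A, (3, 2, 1)))  # (20, 4, 1): third stride is 1, so a copy is made
```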

batched_mul(A::Array{T,3}, B::Matrix)
batched_mul(A::Matrix, B::Array{T,3})
A ⊠ B

This is always matrix-matrix multiplication, but either A or B may lack a batch index.

  • When B is a matrix, result has C[:,:,k] == A[:,:,k] * B[:,:] for all k.

  • When A is a matrix, then C[:,:,k] == A[:,:] * B[:,:,k]. This can also be done by reshaping and calling *, for instance A ⊡ B using TensorCore.jl, but is implemented here using batched_gemm instead of gemm.

julia> randn(16,8,32) ⊠ randn(8,4) |> size
(16, 4, 32)

julia> randn(16,8,32) ⊠ randn(8,4,1) |> size  # equivalent
(16, 4, 32)

julia> randn(16,8) ⊠ randn(8,4,32) |> size
(16, 4, 32)

See also batched_vec to regard B as a batch of vectors, A[:,:,k] * B[:,k].

NNlib.batched_mul! — Function
batched_mul!(C, A, B) -> C
batched_mul!(C, A, B, α=1, β=0)

In-place batched matrix multiplication, equivalent to mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β) for all k. If size(B,3) == 1 then every batch uses B[:,:,1] instead.

This will call batched_gemm! whenever possible. For real arrays this means that, for X ∈ [A,B,C], either stride(X,1)==1 or stride(X,2)==1; the latter may be caused by batched_transpose or, for instance, by PermutedDimsArray(::Array, (3,1,2)). Unlike batched_mul this will never make a copy.

For complex arrays, the wrapper made by batched_adjoint must be outermost to be seen. In this case the strides accepted by BLAS are more restricted: if stride(C,1)==1 then only stride(AorB::BatchedAdjoint,2) == 1 is accepted.

NNlib.batched_adjoint — Function
batched_transpose(A::AbstractArray{T,3})
batched_adjoint(A)

Equivalent to applying transpose or adjoint to each matrix A[:,:,k].

These exist to control how batched_mul behaves, as it operates on such matrix slices of an array with ndims(A)==3.

PermutedDimsArray(A, (2,1,3)) is equivalent to batched_transpose(A), and is also understood by batched_mul (and more widely supported elsewhere).

BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}
BatchedAdjoint{T, S}

Lazy wrappers analogous to Transpose and Adjoint, returned by batched_transpose etc.

NNlib.batched_transpose — Function
batched_transpose(A::AbstractArray{T,3})

Shares its docstring with batched_adjoint above.

NNlib.batched_vec — Function
batched_vec(A::Array{T,3}, B::Matrix)
batched_vec(A::Array{T,3}, b::Vector)

Batched matrix-vector multiplication: the result has C[:,k] == A[:,:,k] * B[:,k] for all k, or else C[:,k] == A[:,:,k] * b for b::Vector.

With the same argument types, batched_mul(A, B) would regard B as a fixed matrix, not a batch of vectors. Both reshape and then call batched_mul(::Array{T,3}, ::Array{T,3}).
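
The two rules can be written as explicit loops over the batch index, which makes the difference between a batch of vectors B[:, k] and one shared vector b visible. `batched_vec_ref` is a hypothetical name, not NNlib's implementation.

```julia
# Hypothetical loop versions of the two batched_vec rules (not NNlib's code).
batched_vec_ref(A::AbstractArray{<:Number,3}, B::AbstractMatrix) =
    reduce(hcat, [A[:, :, k] * B[:, k] for k in axes(A, 3)])  # one vector per batch
batched_vec_ref(A::AbstractArray{<:Number,3}, b::AbstractVector) =
    reduce(hcat, [A[:, :, k] * b for k in axes(A, 3)])        # same b for every k

A, B, b = randn(16, 8, 32), randn(8, 32), randn(8)
size(batched_vec_ref(A, B))  # (16, 32)
size(batched_vec_ref(A, b))  # (16, 32)
```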

julia> A, B, b = randn(16,8,32), randn(8,32), randn(8);

julia> batched_vec(A,B) |> size
(16, 32)

julia> batched_vec(A,b) |> size
(16, 32)

Gather and Scatter

Flux's Embedding layer uses NNlib.gather as its backend.

NNlib.gather — Function
NNlib.gather(src, idx) -> dst

Reverse operation of scatter. Gathers data from source src and writes it in a destination dst according to the index array idx. For each k in CartesianIndices(idx), assign values to dst according to

dst[:, ... , k] .= src[:, ... , idx[k]...]

Notice that if idx is a vector containing integers and src is a matrix, the previous expression simplifies to

dst[:, k] .= src[:, idx[k]]

and k will run over 1:length(idx).

The elements of idx can be integers or integer tuples and may be repeated. A single src column can end up being copied into zero, one, or multiple dst columns.
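
Here is a minimal reference for the simplified case above, with integer indices selecting columns of a matrix or entries of a vector. It reproduces the examples below. `gather_ref` is a hypothetical helper; for plain integer indices it agrees with `src[:, idx]`.

```julia
# Hypothetical reference (not NNlib's code): dst[:, k] .= src[:, idx[k]].
function gather_ref(src::AbstractMatrix, idx::AbstractVector{<:Integer})
    dst = similar(src, size(src, 1), length(idx))
    for (k, i) in enumerate(idx)
        dst[:, k] .= view(src, :, i)  # a column may be copied zero, one, or many times
    end
    return dst
end
gather_ref(src::AbstractVector, idx::AbstractVector{<:Integer}) = [src[i] for i in idx]

gather_ref([1 2 3; 4 5 6], [1, 3, 1, 3, 1])  # [1 3 1 3 1; 4 6 4 6 4]
gather_ref([1, 20, 300, 4000], [2, 4, 2])     # [20, 4000, 20]
```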

See gather! for an in-place version.

Examples

julia> NNlib.gather([1,20,300,4000], [2,4,2])
3-element Vector{Int64}:
   20
 4000
   20

julia> NNlib.gather([1 2 3; 4 5 6], [1,3,1,3,1])
2×5 Matrix{Int64}:
 1  3  1  3  1
 4  6  4  6  4

NNlib.gather! — Function
NNlib.gather!(dst, src, idx)

Reverse operation of scatter!. Gathers data from source src and writes it in destination dst according to the index array idx. For each k in CartesianIndices(idx), assign values to dst according to

dst[:, ... , k] .= src[:, ... , idx[k]...]

Notice that if idx is a vector containing integers, and both dst and src are matrices, the previous expression simplifies to

dst[:, k] .= src[:, idx[k]]

and k will run over 1:length(idx).

The elements of idx can be integers or integer tuples and may be repeated. A single src column can end up being copied into zero, one, or multiple dst columns.

See gather for an allocating version.

NNlib.scatter — Function
NNlib.scatter(op, src, idx; [init, dstsize])

Scatter operation allocating a destination array dst and calling scatter!(op, dst, src, idx) on it.

  • If keyword init is provided, it is used to initialize the content of dst. Otherwise, the init value is inferred from the reduction operator op for some common operators (e.g. init = 0 for op = +).

  • If dstsize is provided, it will be used to define the size of the destination array; otherwise it will be inferred from src and idx.

See scatter! for full details on how idx works.

Examples

julia> NNlib.scatter(+, [10,100,1000], [3,1,2])
3-element Vector{Int64}:
  100
 1000
   10

julia> NNlib.scatter(+, [1 2 3 4; 5 6 7 8], [2,1,1,5])
2×5 Matrix{Int64}:
  5  1  0  0  4
 13  5  0  0  8

julia> NNlib.scatter(*, [10,200,3000], [1,4,2]; init = 10, dstsize = 6)
6-element Vector{Int64}:
   100
 30000
    10
  2000
    10
    10

NNlib.scatter! — Function
NNlib.scatter!(op, dst, src, idx)

Scatter operation, which writes data in src into dst at locations idx. A binary reduction operator op is applied during the scatter. For each index k in idx, accumulates values in dst according to

dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])

See also scatter, gather.

Arguments

  • op: Operation to be applied on dst and src, e.g. +, -, *, /, max, min and mean.
  • dst: The destination for src to aggregate to. This argument will be mutated.
  • src: The source data for aggregating.
  • idx: The mapping for aggregation from source (index) to destination (value). The idx array can contain either integers or tuples.

Examples

julia> NNlib.scatter!(+, ones(3), [10,100], [1,3])
3-element Vector{Float64}:
  11.0
   1.0
 101.0

julia> NNlib.scatter!(*, fill(0.5, 2, 4), [1 10; 100 1000], [3,2])
2×4 Matrix{Float64}:
 0.5    5.0   0.5  0.5
 0.5  500.0  50.0  0.5
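
The reference versions below cover integer idx only, with the last dimension of dst and src being the indexed one. They reproduce the scatter and scatter! examples above. `scatter_ref!`, `scatter_ref`, and the `nslots` keyword are hypothetical. In particular, `nslots` is a simplified stand-in for dstsize that only sets the length of the last dimension, and `init` defaults to 0 here instead of being inferred from op.

```julia
# Hypothetical in-place reference (not NNlib's code):
# dst[:, ..., idx[k]] = op.(dst[:, ..., idx[k]], src[:, ..., k])
function scatter_ref!(op, dst::AbstractArray, src::AbstractArray, idx::AbstractVector{<:Integer})
    for (k, i) in enumerate(idx)
        d = selectdim(dst, ndims(dst), i)  # view of the destination slot
        d .= op.(d, selectdim(src, ndims(src), k))
    end
    return dst
end

# Hypothetical allocating form: fill with `init`, last dim sized by maximum(idx).
function scatter_ref(op, src::AbstractArray, idx::AbstractVector{<:Integer};
                     init = 0, nslots = nothing)
    lastlen = nslots === nothing ? maximum(idx) : nslots
    dst = fill(convert(eltype(src), init), size(src)[1:end-1]..., lastlen)
    return scatter_ref!(op, dst, src, idx)
end

scatter_ref(+, [1 2 3 4; 5 6 7 8], [2, 1, 1, 5])  # [5 1 0 0 4; 13 5 0 0 8]
scatter_ref!(+, ones(3), [10, 100], [1, 3])       # [11.0, 1.0, 101.0]
```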

Sampling

NNlib.grid_sample — Function
grid_sample(input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros)

Given input, compute output by sampling input values at pixel locations from grid. Uses bilinear interpolation to calculate output values.

This implementation assumes the extrema (-1 and 1) are considered as referring to the center points of the input’s corner pixels (i.e. align corners is true).

Arguments

  • input: Input array in (W_in, H_in, C, N) shape.

  • grid: Input grid in (2, W_out, H_out, N) shape. For each (W_out, H_out, N) location, grid contains the (x, y) coordinates that specify the sampling location, normalized by the input shape.

    Therefore, x and y should have values in [-1, 1] range. For example, (x = -1, y = -1) is the left-top pixel of input, and (x = 1, y = 1) is the right-bottom pixel of input.

    Out-of-bound values are handled according to the padding_mode.

  • padding_mode: Out-of-bound padding. :zeros to use 0 for out-of-bound grid locations. :border to use border values for out-of-bound grid locations. Default is :zeros.

Returns

(W_out, H_out, C, N) sampled grid from input.
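
The sketch below shows the coordinate mapping and the two padding modes. It assumes align-corners unnormalization (-1 maps to pixel 1 and +1 maps to the last pixel), that :border clamps the sampling location into the image, and that :zeros treats out-of-bounds neighbours as 0. `sample_bilinear` and `grid_sample_ref` are hypothetical helpers, not NNlib's implementation.

```julia
# Hypothetical reference: bilinear sample of one W×H image at normalized (gx, gy).
function sample_bilinear(img::AbstractMatrix{T}, gx::Real, gy::Real;
                         padding_mode::Symbol = :zeros) where {T<:AbstractFloat}
    W, H = size(img)
    ix = (gx + 1) / 2 * (W - 1) + 1  # -1 -> 1, +1 -> W (corners aligned)
    iy = (gy + 1) / 2 * (H - 1) + 1
    if padding_mode === :border
        ix, iy = clamp(ix, 1, W), clamp(iy, 1, H)  # reuse border values
    elseif padding_mode !== :zeros
        throw(ArgumentError("unsupported padding_mode: $padding_mode"))
    end
    x0, y0 = floor(Int, ix), floor(Int, iy)
    wx, wy = ix - x0, iy - y0
    pix(i, j) = (1 <= i <= W && 1 <= j <= H) ? img[i, j] : zero(T)  # zeros outside
    return (1 - wx) * (1 - wy) * pix(x0, y0) + wx * (1 - wy) * pix(x0 + 1, y0) +
           (1 - wx) * wy * pix(x0, y0 + 1) + wx * wy * pix(x0 + 1, y0 + 1)
end

# Hypothetical driver over input (W_in, H_in, C, N) and grid (2, W_out, H_out, N).
function grid_sample_ref(input::AbstractArray{T,4}, grid::AbstractArray{<:Real,4};
                         padding_mode::Symbol = :zeros) where {T<:AbstractFloat}
    C, N = size(input, 3), size(input, 4)
    Wout, Hout = size(grid, 2), size(grid, 3)
    out = zeros(T, Wout, Hout, C, N)
    for n in 1:N, c in 1:C, h in 1:Hout, w in 1:Wout
        out[w, h, c, n] = sample_bilinear(view(input, :, :, c, n),
                                          grid[1, w, h, n], grid[2, w, h, n];
                                          padding_mode = padding_mode)
    end
    return out
end
```

With the x and grid built in the example below, this sketch returns the same two 3×2 results for :zeros and :border.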

Examples

In the example below, grid contains two out-of-bound sampling locations, which are handled differently, depending on the padding_mode.

julia> x = reshape(collect(1.0:4.0), (2, 2, 1, 1))
2×2×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 1.0  3.0
 2.0  4.0

julia> grid = Array{Float64}(undef, 2, 3, 2, 1);

julia> grid[:, 1, 1, 1] .= (-3, -1);

julia> grid[:, 2, 1, 1] .= (0, -1);

julia> grid[:, 3, 1, 1] .= (1, -1);

julia> grid[:, 1, 2, 1] .= (-1, 1);

julia> grid[:, 2, 2, 1] .= (0, 1);

julia> grid[:, 3, 2, 1] .= (3, 1);

julia> grid_sample(x, grid; padding_mode=:zeros)
3×2×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 0.0  3.0
 1.5  3.5
 2.0  0.0

julia> grid_sample(x, grid; padding_mode=:border)
3×2×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 1.0  3.0
 1.5  3.5
 2.0  4.0

NNlib.∇grid_sample — Function
∇grid_sample(Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros) where T

Arguments

  • Δ: Input gradient in (W_out, H_out, C, N) shape (same as output of the primal computation).
  • input: Input from primal computation in (W_in, H_in, C, N) shape.
  • grid: Grid from primal computation in (2, W_out, H_out, N) shape.
  • padding_mode: Out-of-bound padding. :zeros to use 0 for out-of-bound grid locations. :border to use border values for out-of-bound grid locations. Should be the same as in primal computation. Default is :zeros.

Returns

dinput (same shape as input) and dgrid (same shape as grid) gradients.

Losses

NNlib.ctc_loss — Function
ctc_loss(ŷ, y)

Computes the connectionist temporal classification loss between ŷ and y. ŷ must be a classes-by-time matrix, i.e., each row represents a class and each column represents a time step. Additionally, the logsoftmax function will be applied to ŷ, so ŷ must be the raw activation values from the neural network and not, for example, the activations after being passed through a softmax activation function. y must be a 1D array of the labels associated with ŷ. The blank label is assumed to be the last label category in ŷ, so it is equivalent to size(ŷ, 1).

Used for sequence-to-sequence classification problems such as speech recognition and handwriting recognition where the exact time-alignment of the output (e.g., letters) is not needed to solve the problem. See Graves et al. (2006) or Graves (2012) for mathematical details.
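
Here is a compact sketch of the quantity involved. It computes the negative log-likelihood of y under ŷ using the standard log-space forward (alpha) recursion of Graves et al. (2006), with blank = size(ŷ, 1) and a column-wise log-softmax applied first. `ctc_nll` and `logsumexp_safe` are hypothetical names, and NNlib's ctc_loss may differ from this sketch in details such as normalization.

```julia
# Hypothetical helper: log(sum(exp.(v))) that returns -Inf when every entry is -Inf.
function logsumexp_safe(v)
    m = maximum(v)
    return m == -Inf ? -Inf : m + log(sum(exp.(v .- m)))
end

# Hypothetical reference CTC negative log-likelihood (not NNlib's code).
function ctc_nll(ŷ::AbstractMatrix{<:AbstractFloat}, y::AbstractVector{<:Integer})
    C, T = size(ŷ)
    blank = C
    lp = ŷ .- mapslices(logsumexp_safe, ŷ; dims = 1)  # column-wise log-softmax
    ext = fill(blank, 2 * length(y) + 1)             # blank, y1, blank, y2, ..., blank
    ext[2:2:end] .= y
    S = length(ext)
    α = fill(-Inf, S, T)
    α[1, 1] = lp[blank, 1]
    S > 1 && (α[2, 1] = lp[ext[2], 1])
    for t in 2:T, s in 1:S
        cands = [α[s, t-1]]
        s > 1 && push!(cands, α[s-1, t-1])
        # Skipping a blank is only allowed between two different labels.
        if s > 2 && ext[s] != blank && ext[s] != ext[s-2]
            push!(cands, α[s-2, t-1])
        end
        α[s, t] = logsumexp_safe(cands) + lp[ext[s], t]
    end
    # Valid paths end on the last label or on the trailing blank.
    return -logsumexp_safe(S > 1 ? [α[S, T], α[S-1, T]] : [α[S, T]])
end
```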

Miscellaneous

NNlib.logsumexp — Function
logsumexp(x; dims = :)

Computes log.(sum(exp.(x); dims)) in a numerically stable way. Without dims keyword this returns a scalar.
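
The stable form shifts by the maximum before exponentiating and adds the maximum back outside the log. That is a standard identity, assumed here and not taken from NNlib's source. `my_logsumexp` is a hypothetical name.

```julia
# Hypothetical reference (not NNlib's code): log(sum(exp.(x))) == m + log(sum(exp.(x .- m))).
function my_logsumexp(x::AbstractArray{<:AbstractFloat}; dims = :)
    m = maximum(x; dims = dims)
    return m .+ log.(sum(exp.(x .- m); dims = dims))
end

my_logsumexp([1000.0, 1000.0])              # 1000 + log(2), no overflow
log(sum(exp.([1000.0, 1000.0])))            # Inf: the naive formula overflows
my_logsumexp([1.0 2.0; 3.0 4.0]; dims = 1)  # 1×2 matrix, one value per column
```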

See also logsoftmax.