NNlib

Flux re-exports all of the functions exported by the NNlib package.

Activation Functions

Non-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call σ.(xs), relu.(xs) and so on.
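
For instance (a minimal sketch, assuming Flux or NNlib is loaded):

julia> relu.([-1.0, 0.0, 2.5])
3-element Array{Float64,1}:
 0.0
 0.0
 2.5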

NNlib.hardsigmoidFunction
hardσ(x) = max(0, min(1, (x + 3) / 6))

Piecewise linear approximation of sigmoid.
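
For instance, checking the clipping endpoints (a quick sketch):

julia> hardσ.([-4, 0, 4])
3-element Array{Float64,1}:
 0.0
 0.5
 1.0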

NNlib.leakyreluFunction
leakyrelu(x, a=0.01) = max(a*x, x)

Leaky Rectified Linear Unit activation function. You can also specify the coefficient explicitly, e.g. leakyrelu(x, 0.01).
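
For negative inputs the output is scaled rather than zeroed; for instance:

julia> leakyrelu(-10.0)        # max(0.01 * -10.0, -10.0)
-0.1

julia> leakyrelu(-10.0, 0.2)   # max(0.2 * -10.0, -10.0)
-2.0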

NNlib.lishtFunction
lisht(x) = x * tanh(x)

Non-Parametric Linearly Scaled Hyperbolic Tangent activation function. See the LiSHT paper (https://arxiv.org/abs/1901.05894).
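
Since x * tanh(x) is an even function, lisht treats positive and negative inputs symmetrically; a quick check:

julia> lisht(2.0) ≈ lisht(-2.0) ≈ 2 * tanh(2)
true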

NNlib.logcoshFunction
logcosh(x)

Return log(cosh(x)) which is computed in a numerically stable way.
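
For large |x|, cosh(x) overflows even though log(cosh(x)) is moderate; a quick illustration:

julia> log(cosh(1000.0))   # cosh(1000.0) overflows to Inf
Inf

julia> logcosh(1000.0) ≈ 1000 - log(2)   # the stable evaluation stays finite
true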

NNlib.logsigmoidFunction
logσ(x)

Return log(σ(x)) which is computed in a numerically stable way.
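
For very negative x, σ(x) underflows to zero and the naive log returns -Inf, while logσ stays finite; a quick illustration:

julia> log(σ(-1000.0))
-Inf

julia> logσ(-1000.0)
-1000.0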

NNlib.rreluFunction
rrelu(x, l=1/8, u=1/3) = max(a*x, x)

where a is randomly sampled from the uniform distribution U(l, u).

Randomized Leaky Rectified Linear Unit activation function. You can also specify the bounds explicitly, e.g. rrelu(x, 0.0, 1.0).
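
Since the slope a is drawn anew on each call, the result for negative inputs is random but bounded; a minimal sketch (assuming NNlib is loaded):

x = -10.0
y = rrelu(x)                   # y == a * x, with a sampled from U(1/8, 1/3)
@assert -10/3 <= y <= -10/8    # y lies between about -3.33 and -1.25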

Softmax

NNlib.softmaxFunction
softmax(x; dims=1)

Softmax turns input array x into probability distributions that sum to 1 along the dimensions specified by dims. It is semantically equivalent to the following:

softmax(x; dims=1) = exp.(x) ./ sum(exp.(x), dims=dims)

with additional manipulations enhancing numerical stability.

For a matrix input x it will by default (dims=1) treat it as a batch of vectors, with each column independent. The keyword dims=2 will instead treat rows independently, and so on.

See also logsoftmax.

Examples

julia> softmax([1, 2, 3])
3-element Array{Float64,1}:
  0.0900306
  0.244728
  0.665241

julia> softmax([1 2 3; 2 2 2])  # dims=1
2×3 Array{Float64,2}:
 0.268941  0.5  0.731059
 0.731059  0.5  0.268941

julia> softmax([1 2 3; 2 2 2]; dims=2)
2×3 Array{Float64,2}:
 0.0900306  0.244728  0.665241
 0.333333   0.333333  0.333333

NNlib.logsoftmaxFunction
logsoftmax(x; dims=1)

Computes the log of softmax in a more numerically stable way than directly taking log.(softmax(x)). Commonly used in computing cross entropy loss.

It is semantically equivalent to the following:

logsoftmax(x; dims=1) = x .- log.(sum(exp.(x), dims=dims))

See also softmax.
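
The difference matters for inputs of large magnitude, where entries of softmax underflow to zero; a minimal sketch (assuming NNlib is loaded):

x = [1000.0, 2000.0]
logsoftmax(x)       # gives [-1000.0, 0.0]
log.(softmax(x))    # gives [-Inf, 0.0]; the underflow makes the naive log diverge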

Pooling

NNlib.maxpoolFunction
maxpool(x, k::NTuple; pad=0, stride=k)

Perform a max pooling operation with window size k on the input tensor x.

NNlib.meanpoolFunction
meanpool(x, k::NTuple; pad=0, stride=k)

Perform a mean pooling operation with window size k on the input tensor x.
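
For instance, 1-dimensional pooling over a (width, channels, batch) array (a minimal sketch, assuming NNlib is loaded):

x = reshape(Float32[1, 3, 2, 4, 5, 0], 6, 1, 1)   # (width, channels, batch)
maxpool(x, (2,))    # 3×1×1 array containing 3.0, 4.0, 5.0
meanpool(x, (2,))   # 3×1×1 array containing 2.0, 3.0, 2.5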

Convolution

NNlib.convFunction
conv(x, w; stride=1, pad=0, dilation=1, flipped=false)

Apply convolution filter w to input x. x and w are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
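
For instance, a 1-dimensional convolution (a minimal sketch, assuming NNlib is loaded; the kernel here is symmetric, so the flipped keyword makes no difference):

x = reshape(Float32[1, 2, 3, 4, 5], 5, 1, 1)   # (width, in_channels, batch)
w = reshape(Float32[1, 1], 2, 1, 1)            # (kernel width, in_channels, out_channels)
conv(x, w)   # 4×1×1 array of sliding sums: 3.0, 5.0, 7.0, 9.0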

NNlib.depthwiseconvFunction
depthwiseconv(x, w; stride=1, pad=0, dilation=1, flipped=false)

Depthwise convolution operation with filter w on input x. x and w are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.

Upsampling

NNlib.upsample_nearestFunction
upsample_nearest(x, scale::NTuple{S,Int})
upsample_nearest(x; size::NTuple{S,Int})

Upsamples the array x by integer multiples along the first S dimensions. Subsequent dimensions of x are not altered.

Either the scale factors or the final output size can be specified.

See also upsample_bilinear, for two dimensions of an N=4 array.

Example

julia> upsample_nearest([1 2 3; 4 5 6], (2, 3))
4×9 Array{Int64,2}:
 1  1  1  2  2  2  3  3  3
 1  1  1  2  2  2  3  3  3
 4  4  4  5  5  5  6  6  6
 4  4  4  5  5  5  6  6  6

julia> ans == upsample_nearest([1 2 3; 4 5 6]; size=(4, 9))  # equivalent
true

julia> upsample_nearest([1 2 3; 4 5 6], (2,))
4×3 Array{Int64,2}:
 1  2  3
 1  2  3
 4  5  6
 4  5  6

julia> ans == upsample_nearest([1 2 3; 4 5 6], size=(4,))
true

NNlib.upsample_bilinearFunction
upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real})
upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer})

Upsamples the first 2 dimensions of the array x by the upsample factors stored in scale, using bilinear interpolation. As an alternative to using scale, the resulting image size can be directly specified with a keyword argument.

The size of the output is equal to (scale[1]*S1, scale[2]*S2, S3, S4), where S1, S2, S3, S4 = size(x).

Examples

julia> x = reshape(Float32[1 2 3; 4 5 6], (2,3,1,1))
2×3×1×1 Array{Float32,4}:
[:, :, 1, 1] =
 1.0  2.0  3.0
 4.0  5.0  6.0

julia> upsample_bilinear(x, (2, 3))
4×9×1×1 Array{Float32,4}:
[:, :, 1, 1] =
 1.0  1.25  1.5  1.75  2.0  2.25  2.5  2.75  3.0
 2.0  2.25  2.5  2.75  3.0  3.25  3.5  3.75  4.0
 3.0  3.25  3.5  3.75  4.0  4.25  4.5  4.75  5.0
 4.0  4.25  4.5  4.75  5.0  5.25  5.5  5.75  6.0

julia> ans == upsample_bilinear(x; size=(4, 9))  # specify output size instead
true

julia> upsample_bilinear(x, (2.5, 3.5))  # non-integer scaling factors are allowed
5×10×1×1 Array{Float32,4}:
[:, :, 1, 1] =
 1.0   1.22222  1.44444  1.66667  1.88889  2.11111  2.33333  2.55556  2.77778  3.0
 1.75  1.97222  2.19444  2.41667  2.63889  2.86111  3.08333  3.30556  3.52778  3.75
 2.5   2.72222  2.94444  3.16667  3.38889  3.61111  3.83333  4.05556  4.27778  4.5
 3.25  3.47222  3.69444  3.91667  4.13889  4.36111  4.58333  4.80556  5.02778  5.25
 4.0   4.22222  4.44444  4.66667  4.88889  5.11111  5.33333  5.55556  5.77778  6.0

NNlib.pixel_shuffleFunction
pixel_shuffle(x, r::Integer)

Pixel shuffling operation, upscaling by a factor r.

For 4-dimensional arrays representing N images, the operation converts an input of size(x) == (W, H, r^2*C, N) to an output of size (r*W, r*H, C, N). For D-dimensional data, it expects ndims(x) == D+2 with channel and batch dimensions, and divides the number of channels by r^D.

Used in super-resolution networks to upsample towards high-resolution features. Reference: Shi et al., "Real-Time Single Image and Video Super-Resolution ...", CVPR 2016, https://arxiv.org/abs/1609.05158

Examples

julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
2×3×4×1 Array{Float64,4}:
[:, :, 1, 1] =
 11.1  12.1  13.1
 21.1  22.1  23.1

[:, :, 2, 1] =
 11.2  12.2  13.2
 21.2  22.2  23.2

[:, :, 3, 1] =
 11.3  12.3  13.3
 21.3  22.3  23.3

[:, :, 4, 1] =
 11.4  12.4  13.4
 21.4  22.4  23.4

julia> pixel_shuffle(x, 2)  # 4 channels used up as 2x upscaling of image dimensions
4×6×1×1 Array{Float64,4}:
[:, :, 1, 1] =
 11.1  11.3  12.1  12.3  13.1  13.3
 11.2  11.4  12.2  12.4  13.2  13.4
 21.1  21.3  22.1  22.3  23.1  23.3
 21.2  21.4  22.2  22.4  23.2  23.4

julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1]
3×6×1 Array{Float64,3}:
[:, :, 1] =
 1.1  1.2  1.3  1.4  1.5  1.6
 2.1  2.2  2.3  2.4  2.5  2.6
 3.1  3.2  3.3  3.4  3.5  3.6

julia> pixel_shuffle(y, 2)  # 1D image, with 6 channels reduced to 3
6×3×1 Array{Float64,3}:
[:, :, 1] =
 1.1  1.3  1.5
 1.2  1.4  1.6
 2.1  2.3  2.5
 2.2  2.4  2.6
 3.1  3.3  3.5
 3.2  3.4  3.6

Batched Operations

NNlib.batched_mulFunction
batched_mul(A, B) -> C
A ⊠ B  # \boxtimes

Batched matrix multiplication. Result has C[:,:,k] == A[:,:,k] * B[:,:,k] for all k. If size(B,3) == 1 then instead C[:,:,k] == A[:,:,k] * B[:,:,1], and similarly for A.

To transpose each matrix, apply batched_transpose to the array, or batched_adjoint for conjugate-transpose:

julia> A, B = randn(2,5,17), randn(5,9,17);

julia> A ⊠ B |> size
(2, 9, 17)

julia> batched_adjoint(A) |> size
(5, 2, 17)

julia> batched_mul(A, batched_adjoint(randn(9,5,17))) |> size
(2, 9, 17)

julia> A ⊠ randn(5,9,1) |> size
(2, 9, 17)

julia> batched_transpose(A) == PermutedDimsArray(A, (2,1,3))
true

The equivalent PermutedDimsArray may be used in place of batched_transpose. Other permutations are also handled by BLAS, provided that the batch index k is not the first dimension of the underlying array. Thus PermutedDimsArray(::Array, (1,3,2)) and PermutedDimsArray(::Array, (3,1,2)) are fine.

However, A = PermutedDimsArray(::Array, (3,2,1)) is not acceptable to BLAS, since the batch dimension is the contiguous one: stride(A,3) == 1. This will be copied, as doing so is faster than batched_mul_generic!.

Both this copy and batched_mul_generic! produce @debug messages, and setting for instance ENV["JULIA_DEBUG"] = "NNlib" will display them.

batched_mul(A::Array{T,3}, B::Matrix)
batched_mul(A::Matrix, B::Array{T,3})
A ⊠ B

This is always matrix-matrix multiplication, but either A or B may lack a batch index.

  • When B is a matrix, result has C[:,:,k] == A[:,:,k] * B[:,:] for all k.

  • When A is a matrix, then C[:,:,k] == A[:,:] * B[:,:,k]. This can also be done by reshaping and calling *, for instance A ⊡ B using TensorCore.jl, but is implemented here using batched_gemm instead of gemm.

julia> randn(16,8,32) ⊠ randn(8,4) |> size
(16, 4, 32)

julia> randn(16,8,32) ⊠ randn(8,4,1) |> size  # equivalent
(16, 4, 32)

julia> randn(16,8) ⊠ randn(8,4,32) |> size
(16, 4, 32)

See also batched_vec to regard B as a batch of vectors, A[:,:,k] * B[:,k].

NNlib.batched_mul!Function
batched_mul!(C, A, B) -> C
batched_mul!(C, A, B, α=1, β=0)

In-place batched matrix multiplication, equivalent to mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β) for all k. If size(B,3) == 1 then every batch uses B[:,:,1] instead.

This will call batched_gemm! whenever possible. For real arrays this means that, for each X ∈ [A,B,C], either stride(X,1) == 1 or stride(X,2) == 1 must hold; the latter may be caused by batched_transpose or by, for instance, PermutedDimsArray(::Array, (3,1,2)). Unlike batched_mul, this will never make a copy.

For complex arrays, the wrapper made by batched_adjoint must be outermost to be seen. In this case the strides accepted by BLAS are more restricted: if stride(C,1) == 1, then only stride(AorB::BatchedAdjoint, 2) == 1 is accepted.
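
A minimal sketch of in-place use (assuming NNlib is loaded):

A, B = randn(2, 3, 4), randn(3, 5, 4)
C = zeros(2, 5, 4)
batched_mul!(C, A, B)
C[:,:,1] ≈ A[:,:,1] * B[:,:,1]   # true, and likewise for each batch index k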

NNlib.batched_adjointFunction
batched_transpose(A::AbstractArray{T,3})
batched_adjoint(A)

Equivalent to applying transpose or adjoint to each matrix A[:,:,k].

These exist to control how batched_mul behaves, as it operates on such matrix slices of an array with ndims(A)==3.

PermutedDimsArray(A, (2,1,3)) is equivalent to batched_transpose(A), and is also understood by batched_mul (and more widely supported elsewhere).

BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}
BatchedAdjoint{T, S}

Lazy wrappers analogous to Transpose and Adjoint, returned by batched_transpose etc.

NNlib.batched_transposeFunction
batched_transpose(A::AbstractArray{T,3})
batched_adjoint(A)

Equivalent to applying transpose or adjoint to each matrix A[:,:,k]. See batched_adjoint above; the two functions share a docstring.