NNlib
Flux re-exports all of the functions exported by the NNlib package.
Activation Functions
Non-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call σ.(xs), relu.(xs) and so on.
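For instance, a minimal usage sketch (the input values are arbitrary; either using Flux or using NNlib provides these functions):
julia> relu.([-1.0, 0.0, 2.0])   # broadcast the scalar activation over each element
3-element Array{Float64,1}:
0.0
0.0
2.0
julia> σ.([0.0, 1.0])            # the same pattern works for σ, elu, gelu, and so on
2-element Array{Float64,1}:
0.5
0.731059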
NNlib.celu — Function
celu(x, α=1) = x ≥ 0 ? x : α * (exp(x/α) - 1)
Continuously Differentiable Exponential Linear Units. See Continuously Differentiable Exponential Linear Units.
NNlib.elu — Function
elu(x, α=1) = x > 0 ? x : α * (exp(x) - 1)
Exponential Linear Unit activation function. See Fast and Accurate Deep Network Learning by Exponential Linear Units. You can also specify the coefficient explicitly, e.g. elu(x, 1).
NNlib.gelu — Function
gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
Gaussian Error Linear Unit activation function.
NNlib.hardsigmoid — Function
hardσ(x) = max(0, min(1, (x + 3) / 6))
Piecewise linear approximation of sigmoid.
NNlib.hardtanh — Function
hardtanh(x) = max(-1, min(1, x))
Segment-wise linear approximation of tanh, cheaper and more computationally efficient than tanh. See Large Scale Machine Learning.
NNlib.leakyrelu — Function
leakyrelu(x, a=0.01) = max(a*x, x)
Leaky Rectified Linear Unit activation function. You can also specify the coefficient explicitly, e.g. leakyrelu(x, 0.01).
NNlib.lisht — Function
lisht(x) = x * tanh(x)
Non-Parametric Linearly Scaled Hyperbolic Tangent activation function. See LiSHT.
NNlib.logcosh — Function
logcosh(x)
Return log(cosh(x)) which is computed in a numerically stable way.
NNlib.logsigmoid — Function
logσ(x)
Return log(σ(x)) which is computed in a numerically stable way.
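As a brief illustration (not part of the original docstring) of why the stable form matters, the naive composition underflows for large negative arguments while logσ does not:
julia> logσ(-1000.0)
-1000.0
julia> log(σ(-1000.0))   # σ underflows to 0.0, so the naive log gives -Inf
-Inf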
NNlib.mish — Function
mish(x) = x * tanh(softplus(x))
Self Regularized Non-Monotonic Neural Activation Function. See Mish: A Self Regularized Non-Monotonic Neural Activation Function.
NNlib.relu — Function
relu(x) = max(0, x)
Rectified Linear Unit activation function.
NNlib.relu6 — Function
relu6(x) = min(max(0, x), 6)
Rectified Linear Unit activation function capped at 6. See Convolutional Deep Belief Networks on CIFAR-10.
NNlib.rrelu — Function
rrelu(x, l=1/8, u=1/3) = max(a*x, x)
a = randomly sampled from uniform distribution U(l, u)
Randomized Leaky Rectified Linear Unit activation function. You can also specify the bounds explicitly, e.g. rrelu(x, 0.0, 1.0).
NNlib.selu — Function
selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1))
λ ≈ 1.05070...
α ≈ 1.67326...
Scaled exponential linear units. See Self-Normalizing Neural Networks.
NNlib.sigmoid — Function
σ(x) = 1 / (1 + exp(-x))
Classic sigmoid activation function.
NNlib.softplus — Function
softplus(x) = log(exp(x) + 1)
NNlib.softshrink — Function
softshrink(x, λ=0.5) = (x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0))
NNlib.softsign — Function
softsign(x) = x / (1 + |x|)
NNlib.swish — Function
swish(x) = x * σ(x)
Self-gated activation function. See Swish: a Self-Gated Activation Function.
NNlib.tanhshrink — Function
tanhshrink(x) = x - tanh(x)
NNlib.trelu — Function
trelu(x, theta=1) = x > theta ? x : 0
Threshold Gated Rectified Linear activation function. See ThresholdRelu.
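As a quick sketch of softshrink with the default λ = 0.5 (inputs with |x| ≤ λ are shrunk to zero, larger ones are shifted towards zero by λ):
julia> softshrink(1.0)
0.5
julia> softshrink(-1.0)
-0.5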
Softmax
NNlib.softmax — Function
softmax(x; dims=1)
Softmax turns input array x into probability distributions that sum to 1 along the dimensions specified by dims. It is semantically equivalent to the following:
softmax(x; dims=1) = exp.(x) ./ sum(exp.(x), dims=dims)
with additional manipulations enhancing numerical stability.
For a matrix input x it will by default (dims=1) treat it as a batch of vectors, with each column independent. Keyword dims=2 will instead treat rows independently, etc...
See also logsoftmax.
Examples
julia> softmax([1, 2, 3])
3-element Array{Float64,1}:
0.0900306
0.244728
0.665241
julia> softmax([1 2 3; 2 2 2]) # dims=1
2×3 Array{Float64,2}:
0.268941 0.5 0.731059
0.731059 0.5 0.268941
julia> softmax([1 2 3; 2 2 2]; dims=2)
2×3 Array{Float64,2}:
0.0900306 0.244728 0.665241
0.333333 0.333333 0.333333
NNlib.logsoftmax — Function
logsoftmax(x; dims=1)
Computes the log of softmax in a more numerically stable way than directly taking log.(softmax(xs)). Commonly used in computing cross entropy loss.
It is semantically equivalent to the following:
logsoftmax(x; dims=1) = x .- log.(sum(exp.(x), dims=dims))
See also softmax.
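A short sketch, reusing the matrix from the softmax examples above, checking the semantic equivalence up to floating-point error:
julia> logsoftmax([1 2 3; 2 2 2]) ≈ log.(softmax([1 2 3; 2 2 2]))
true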
Pooling
NNlib.maxpool — Function
maxpool(x, k::NTuple; pad=0, stride=k)
Perform max pool operation with window size k on input tensor x.
NNlib.meanpool — Function
meanpool(x, k::NTuple; pad=0, stride=k)
Perform mean pool operation with window size k on input tensor x.
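For 2d pooling the input x should be a 4d tensor (two spatial dimensions, then channel and batch); a minimal sketch with sizes chosen only for illustration:
julia> x = reshape(Float32.(1:16), (4, 4, 1, 1));
julia> maxpool(x, (2, 2)) |> size   # default stride equals the window size
(2, 2, 1, 1)
julia> meanpool(x, (2, 2))[:, :, 1, 1]
2×2 Array{Float32,2}:
3.5  11.5
5.5  13.5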
Convolution
NNlib.conv — Function
conv(x, w; stride=1, pad=0, dilation=1, flipped=false)
Apply convolution filter w to input x. x and w are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
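For example, a minimal 2d sketch where x is laid out as (width, height, channels, batch) and w as (width, height, input channels, output channels); the sizes here are arbitrary:
julia> x = rand(Float32, 28, 28, 3, 1);    # one 28×28 image with 3 channels
julia> w = rand(Float32, 5, 5, 3, 7);      # seven 5×5 filters over the 3 input channels
julia> conv(x, w) |> size                  # spatial size shrinks to 28 - 5 + 1 = 24
(24, 24, 7, 1)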
NNlib.depthwiseconv — Function
depthwiseconv(x, w; stride=1, pad=0, dilation=1, flipped=false)
Depthwise convolution operation with filter w on input x. x and w are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
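A sketch of the depthwise case, assuming the filter layout (width, height, channel multiplier, input channels), so that the output has channel multiplier × input channels channels:
julia> x = rand(Float32, 28, 28, 3, 1);
julia> w = rand(Float32, 5, 5, 2, 3);      # channel multiplier 2 for each of the 3 input channels
julia> depthwiseconv(x, w) |> size
(24, 24, 6, 1)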
Upsampling
NNlib.upsample_nearest — Function
upsample_nearest(x, scale::NTuple{S,Int})
upsample_nearest(x; size::NTuple{S,Int})
Upsamples the array x by integer multiples along the first S dimensions. Subsequent dimensions of x are not altered.
Either the scale factors or the final output size can be specified.
See also upsample_bilinear, for two dimensions of an N=4 array.
Example
julia> upsample_nearest([1 2 3; 4 5 6], (2, 3))
4×9 Array{Int64,2}:
1 1 1 2 2 2 3 3 3
1 1 1 2 2 2 3 3 3
4 4 4 5 5 5 6 6 6
4 4 4 5 5 5 6 6 6
julia> ans == upsample_nearest([1 2 3; 4 5 6]; size=(4, 9)) # equivalent
true
julia> upsample_nearest([1 2 3; 4 5 6], (2,))
4×3 Array{Int64,2}:
1 2 3
1 2 3
4 5 6
4 5 6
julia> ans == upsample_nearest([1 2 3; 4 5 6], size=(4,))
true
NNlib.upsample_bilinear — Function
upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real})
upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer})
Upsamples the first 2 dimensions of the array x by the upsample factors stored in scale, using bilinear interpolation. As an alternative to using scale, the resulting image size can be directly specified with a keyword argument.
The size of the output is equal to (scale[1]*S1, scale[2]*S2, S3, S4), where S1, S2, S3, S4 = size(x).
Examples
julia> x = reshape(Float32[1 2 3; 4 5 6], (2,3,1,1))
2×3×1×1 Array{Float32,4}:
[:, :, 1, 1] =
1.0 2.0 3.0
4.0 5.0 6.0
julia> upsample_bilinear(x, (2, 3))
4×9×1×1 Array{Float32,4}:
[:, :, 1, 1] =
1.0 1.25 1.5 1.75 2.0 2.25 2.5 2.75 3.0
2.0 2.25 2.5 2.75 3.0 3.25 3.5 3.75 4.0
3.0 3.25 3.5 3.75 4.0 4.25 4.5 4.75 5.0
4.0 4.25 4.5 4.75 5.0 5.25 5.5 5.75 6.0
julia> ans == upsample_bilinear(x; size=(4, 9)) # specify output size instead
true
julia> upsample_bilinear(x, (2.5, 3.5)) # non-integer scaling factors are allowed
5×10×1×1 Array{Float32,4}:
[:, :, 1, 1] =
1.0 1.22222 1.44444 1.66667 1.88889 2.11111 2.33333 2.55556 2.77778 3.0
1.75 1.97222 2.19444 2.41667 2.63889 2.86111 3.08333 3.30556 3.52778 3.75
2.5 2.72222 2.94444 3.16667 3.38889 3.61111 3.83333 4.05556 4.27778 4.5
3.25 3.47222 3.69444 3.91667 4.13889 4.36111 4.58333 4.80556 5.02778 5.25
4.0 4.22222 4.44444 4.66667 4.88889 5.11111 5.33333 5.55556 5.77778 6.0
NNlib.pixel_shuffle — Function
pixel_shuffle(x, r::Integer)
Pixel shuffling operation, upscaling by a factor r.
For 4-dimensional arrays representing N images, the operation converts an input of size(x) == (W, H, r^2*C, N) to output of size (r*W, r*H, C, N). For D-dimensional data, it expects ndims(x) == D+2 with channel and batch dimensions, and divides the number of channels by r^D.
Used in super-resolution networks to upsample towards high-resolution features. Reference: Shi et al., "Real-Time Single Image and Video Super-Resolution ...", CVPR 2016, https://arxiv.org/abs/1609.05158
Examples
julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
2×3×4×1 Array{Float64,4}:
[:, :, 1, 1] =
11.1 12.1 13.1
21.1 22.1 23.1
[:, :, 2, 1] =
11.2 12.2 13.2
21.2 22.2 23.2
[:, :, 3, 1] =
11.3 12.3 13.3
21.3 22.3 23.3
[:, :, 4, 1] =
11.4 12.4 13.4
21.4 22.4 23.4
julia> pixel_shuffle(x, 2) # 4 channels used up as 2x upscaling of image dimensions
4×6×1×1 Array{Float64,4}:
[:, :, 1, 1] =
11.1 11.3 12.1 12.3 13.1 13.3
11.2 11.4 12.2 12.4 13.2 13.4
21.1 21.3 22.1 22.3 23.1 23.3
21.2 21.4 22.2 22.4 23.2 23.4
julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1]
3×6×1 Array{Float64, 3}:
[:, :, 1] =
1.1 1.2 1.3 1.4 1.5 1.6
2.1 2.2 2.3 2.4 2.5 2.6
3.1 3.2 3.3 3.4 3.5 3.6
julia> pixel_shuffle(y, 2) # 1D image, with 6 channels reduced to 3
6×3×1 Array{Float64,3}:
[:, :, 1] =
1.1 1.3 1.5
1.2 1.4 1.6
2.1 2.3 2.5
2.2 2.4 2.6
3.1 3.3 3.5
3.2 3.4 3.6
Batched Operations
NNlib.batched_mul — Function
batched_mul(A, B) -> C
A ⊠ B  # \boxtimes
Batched matrix multiplication. The result has C[:,:,k] == A[:,:,k] * B[:,:,k] for all k. If size(B,3) == 1 then instead C[:,:,k] == A[:,:,k] * B[:,:,1], and similarly for A.
To transpose each matrix, apply batched_transpose to the array, or batched_adjoint for conjugate-transpose:
julia> A, B = randn(2,5,17), randn(5,9,17);
julia> A ⊠ B |> size
(2, 9, 17)
julia> batched_adjoint(A) |> size
(5, 2, 17)
julia> batched_mul(A, batched_adjoint(randn(9,5,17))) |> size
(2, 9, 17)
julia> A ⊠ randn(5,9,1) |> size
(2, 9, 17)
julia> batched_transpose(A) == PermutedDimsArray(A, (2,1,3))
true
The equivalent PermutedDimsArray may be used in place of batched_transpose. Other permutations are also handled by BLAS, provided that the batch index k is not the first dimension of the underlying array. Thus PermutedDimsArray(::Array, (1,3,2)) and PermutedDimsArray(::Array, (3,1,2)) are fine.
However, A = PermutedDimsArray(::Array, (3,2,1)) is not acceptable to BLAS, since the batch dimension is the contiguous one: stride(A,3) == 1. This will be copied, as doing so is faster than batched_mul_generic!.
Both this copy and batched_mul_generic! produce @debug messages, and setting for instance ENV["JULIA_DEBUG"] = NNlib will display them.
batched_mul(A::Array{T,3}, B::Matrix)
batched_mul(A::Matrix, B::Array{T,3})
A ⊠ B
This is always matrix-matrix multiplication, but either A or B may lack a batch index.
- When B is a matrix, the result has C[:,:,k] == A[:,:,k] * B[:,:] for all k.
- When A is a matrix, then C[:,:,k] == A[:,:] * B[:,:,k]. This can also be done by reshaping and calling *, for instance A ⊡ B using TensorCore.jl, but it is implemented here using batched_gemm instead of gemm.
julia> randn(16,8,32) ⊠ randn(8,4) |> size
(16, 4, 32)
julia> randn(16,8,32) ⊠ randn(8,4,1) |> size # equivalent
(16, 4, 32)
julia> randn(16,8) ⊠ randn(8,4,32) |> size
(16, 4, 32)
See also batched_vec to regard B as a batch of vectors, A[:,:,k] * B[:,k].
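A short sketch of that batched matrix-vector form (qualified with the module name in case batched_vec is not re-exported):
julia> A, b = randn(4, 5, 7), randn(5, 7);
julia> NNlib.batched_vec(A, b) |> size   # one matrix-vector product per batch entry
(4, 7)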
NNlib.batched_mul! — Function
batched_mul!(C, A, B) -> C
batched_mul!(C, A, B, α=1, β=0)
In-place batched matrix multiplication, equivalent to mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β) for all k. If size(B,3) == 1 then every batch uses B[:,:,1] instead.
This will call batched_gemm! whenever possible. For real arrays this means that, for X ∈ [A,B,C], either stride(X,1) == 1 or stride(X,2) == 1 must hold; the latter may be caused by batched_transpose or, for instance, by PermutedDimsArray(::Array, (3,1,2)). Unlike batched_mul, this will never make a copy.
For complex arrays, the wrapper made by batched_adjoint must be the outermost wrapper to be seen. In this case the strides accepted by BLAS are more restricted: if stride(C,1) == 1, then only stride(AorB::BatchedAdjoint,2) == 1 is accepted.
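A minimal in-place usage sketch:
julia> A, B = randn(2, 5, 17), randn(5, 9, 17);
julia> C = zeros(2, 9, 17);
julia> batched_mul!(C, A, B);
julia> C ≈ A ⊠ B
true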
NNlib.batched_adjoint — Function
batched_transpose(A::AbstractArray{T,3})
batched_adjoint(A)
Equivalent to applying transpose or adjoint to each matrix A[:,:,k].
These exist to control how batched_mul behaves, as it operates on such matrix slices of an array with ndims(A)==3.
PermutedDimsArray(A, (2,1,3)) is equivalent to batched_transpose(A), and is also understood by batched_mul (and more widely supported elsewhere).
BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}
BatchedAdjoint{T, S}
Lazy wrappers analogous to Transpose and Adjoint, returned by batched_transpose etc.
NNlib.batched_transpose — Function
batched_transpose(A::AbstractArray{T,3})
batched_adjoint(A)
Equivalent to applying transpose or adjoint to each matrix A[:,:,k].
These exist to control how batched_mul behaves, as it operates on such matrix slices of an array with ndims(A)==3.
PermutedDimsArray(A, (2,1,3)) is equivalent to batched_transpose(A), and is also understood by batched_mul (and more widely supported elsewhere).
BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}
BatchedAdjoint{T, S}
Lazy wrappers analogous to Transpose and Adjoint, returned by batched_transpose etc.
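A small sketch showing that each slice is transposed or conjugate-transposed:
julia> A = randn(ComplexF64, 3, 4, 2);
julia> batched_transpose(A) |> size
(4, 3, 2)
julia> batched_adjoint(A)[:, :, 1] == adjoint(A[:, :, 1])
true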