Utility Functions
Flux contains some utility functions for working with data; these functions help create inputs for your models or batch your dataset. Other functions can be used to initialize your layers or to regularly execute callback functions.
Working with Data
Flux.unsqueeze — Function
unsqueeze(xs, dim)
Return xs reshaped into an array one dimensionality higher than xs, where dim indicates in which dimension xs is extended.
Examples
julia> Flux.unsqueeze([1 2; 3 4], 2)
2×1×2 Array{Int64,3}:
[:, :, 1] =
1
3
[:, :, 2] =
2
4
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Array{Array{Int64,1},1}:
[1, 2]
[3, 4]
[5, 6]
julia> Flux.unsqueeze(xs, 1)
1×3 Array{Array{Int64,1},2}:
[1, 2] [3, 4] [5, 6]
unsqueeze(dim)
Returns a function which, acting on an array, inserts a dimension of size 1 at dim.
Examples
julia> rand(21, 22, 23) |> Flux.unsqueeze(2) |> size
(21, 1, 22, 23)
julia> m = Chain(Flux.unsqueeze(3), Flux.unsqueeze(4), Conv((3,3), 1=>7, pad=SamePad()));
julia> rand(Float32, 10, 10) |> m |> size
(10, 10, 7, 1)
Flux.stack — Function
stack(xs, dim)
Concatenate the given Array of Arrays xs into a single Array along the given dimension dim.
Examples
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Array{Array{Int64,1},1}:
[1, 2]
[3, 4]
[5, 6]
julia> Flux.stack(xs, 1)
3×2 Array{Int64,2}:
1 2
3 4
5 6
julia> cat(xs, dims=1)
3-element Array{Array{Int64,1},1}:
[1, 2]
[3, 4]
[5, 6]
Flux.unstack — Function
unstack(xs, dim)
Unroll the given xs into an Array of Arrays along the given dimension dim.
Examples
julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
4-element Array{Array{Int64,1},1}:
[1, 2]
[3, 4]
[5, 6]
[7, 8]
Flux.chunk — Function
chunk(xs, n)
Split xs into n parts.
Examples
julia> Flux.chunk(1:10, 3)
3-element Array{UnitRange{Int64},1}:
1:4
5:8
9:10
julia> Flux.chunk(collect(1:10), 3)
3-element Array{SubArray{Int64,1,Array{Int64,1},Tuple{UnitRange{Int64}},true},1}:
[1, 2, 3, 4]
[5, 6, 7, 8]
[9, 10]
Flux.frequencies — Function
frequencies(xs)
Count the number of times that each element of xs appears.
Examples
julia> Flux.frequencies(['a','b','b'])
Dict{Char,Int64} with 2 entries:
'a' => 1
'b' => 2
Flux.batch — Function
batch(xs)
Batch the arrays in xs into a single array.
Examples
julia> Flux.batch([[1,2,3],[4,5,6]])
3×2 Array{Int64,2}:
1 4
2 5
3 6
Flux.batchseq — Function
batchseq(seqs, pad)
Take a list of N sequences and turn them into a single sequence where each item is a batch of N. Short sequences will be padded by pad.
Examples
julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
3-element Array{Array{Int64,1},1}:
[1, 4]
[2, 5]
[3, 0]
Base.rpad — Method
Return the given sequence padded with p up to a maximum length of n.
Examples
julia> rpad([1, 2], 4, 0)
4-element Array{Int64,1}:
1
2
0
0
julia> rpad([1, 2, 3], 2, 0)
3-element Array{Int64,1}:
1
2
3
Layer Initialization
These are primarily useful if you are planning to write your own layers. Flux initializes convolutional layers and recurrent cells with glorot_uniform by default. To change the default on an applicable layer, pass the desired function with the init keyword. For example:
julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
Conv((3, 3), 1=>8, relu)
Flux.glorot_uniform — Function
glorot_uniform([rng=GLOBAL_RNG], dims...)
Return an Array of size dims containing random variables taken from a uniform distribution in the interval [-x, x], where x = sqrt(6 / (fan_in + fan_out)).
This method is described in [1] and also known as Xavier initialization.
Examples
julia> Flux.glorot_uniform(2, 3)
2×3 Array{Float32,2}:
0.601094 -0.57414 -0.814925
0.900868 0.805994 0.057514
See also
- glorot initialization using normal distribution: glorot_normal
- kaiming initialization using normal distribution: kaiming_normal
- kaiming initialization using uniform distribution: kaiming_uniform
- sparse initialization: sparse_init
- calculation of fan_in and fan_out: nfan
References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
Flux.glorot_normal — Function
glorot_normal([rng=GLOBAL_RNG], dims...)
Return an Array of size dims containing random variables taken from a normal distribution with mean 0 and standard deviation sqrt(2 / (fan_in + fan_out)).
This method is described in [1] and also known as Xavier initialization.
Examples
julia> Flux.glorot_normal(3, 2)
3×2 Array{Float32,2}:
0.429505 -0.0852891
0.523935 0.371009
-0.223261 0.188052
See also
- glorot initialization using uniform distribution: glorot_uniform
- kaiming initialization using normal distribution: kaiming_normal
- kaiming initialization using uniform distribution: kaiming_uniform
- sparse initialization: sparse_init
- calculation of fan_in and fan_out: nfan
References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
Flux.kaiming_uniform — Function
kaiming_uniform([rng=GLOBAL_RNG], dims...; gain = √2)
Return an Array of size dims containing random variables taken from a uniform distribution in the interval [-x, x], where x = gain * sqrt(3 / fan_in).
This method is described in [1] and also known as He initialization.
Examples
julia> Flux.kaiming_uniform(3, 2)
3×2 Array{Float32,2}:
0.950413 1.27439
1.4244 -1.28851
-0.907795 0.0909376
See also
- kaiming initialization using normal distribution: kaiming_normal
- glorot initialization using normal distribution: glorot_normal
- glorot initialization using uniform distribution: glorot_uniform
- sparse initialization: sparse_init
- calculation of fan_in and fan_out: nfan
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
Flux.kaiming_normal — Function
kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2)
Return an Array of size dims containing random variables taken from a normal distribution with mean 0 and standard deviation gain / sqrt(fan_in).
This method is described in [1] and also known as He initialization.
Examples
julia> Flux.kaiming_normal(3, 2)
3×2 Array{Float32,2}:
0.679107 -0.134854
0.828413 0.586617
-0.353007 0.297336
See also
- kaiming initialization using uniform distribution: kaiming_uniform
- glorot initialization using normal distribution: glorot_normal
- glorot initialization using uniform distribution: glorot_uniform
- sparse initialization: sparse_init
- calculation of fan_in and fan_out: nfan
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
Flux.orthogonal — Function
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
Return an Array of size dims which is a (semi) orthogonal matrix, as described in [1].
The input must have at least 2 dimensions. For length(dims) > 2, a prod(dims[1:(end - 1)]) by dims[end] orthogonal matrix is computed before reshaping it to the original dimensions.
Examples
julia> W = Flux.orthogonal(5, 7);
julia> summary(W)
"5×7 Array{Float32,2}"
julia> W * W' ≈ I(5)
true
julia> W2 = Flux.orthogonal(7, 5);
julia> W2 * W2' ≈ I(7)
false
julia> W2' * W2 ≈ I(5)
true
julia> W3 = Flux.orthogonal(3, 3, 2, 4);
julia> transpose(reshape(W3, :, 4)) * reshape(W3, :, 4) ≈ I(4)
true
See also
- kaiming initialization using normal distribution: kaiming_normal
- kaiming initialization using uniform distribution: kaiming_uniform
- glorot initialization using normal distribution: glorot_normal
- glorot initialization using uniform distribution: glorot_uniform
- sparse initialization: sparse_init
References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
Flux.sparse_init — Function
sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)
Return an Array of size dims where each column contains a fixed fraction of zero elements given by sparsity. Non-zero elements are normally distributed with a mean of zero and standard deviation std.
This method is described in [1].
Examples
julia> Flux.sparse_init(3, 2, sparsity=0.1)
3×2 Array{Float32,2}:
0.00828413 0.0
-0.00353007 0.00297336
0.0 0.00586617
See also
- kaiming initialization using normal distribution: kaiming_normal
- kaiming initialization using uniform distribution: kaiming_uniform
- glorot initialization using normal distribution: glorot_normal
- glorot initialization using uniform distribution: glorot_uniform
References
[1] Martens, J, "Deep learning via Hessian-free optimization" Proceedings of the 27th International Conference on International Conference on Machine Learning. 2010.
Model Building
Flux provides some utility functions to help you generate models in an automated fashion.
outputsize enables you to calculate the output sizes of layers like Conv when applied to input samples of a given size. This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. outputsize(f, inputsize) works for all layers (including custom layers) out of the box. By default, inputsize is expected to include the batch dimension; if you leave the batch dimension out, call outputsize(f, inputsize; padbatch=true) and a batch size of one is assumed.
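For example, a quick check on a single convolutional layer might look like the sketch below. The layer and input size are only illustrative: a 3×3 convolution with no padding shrinks each spatial dimension by 2, so a 32×32 RGB input yields a 30×30 output with 16 channels and the padded batch dimension of 1.
julia> Flux.outputsize(Conv((3, 3), 3 => 16), (32, 32, 3); padbatch=true)
(30, 30, 16, 1)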
Using this utility function lets you automate model building for various inputs like so:
"""
make_model(width, height, inchannels, nclasses;
layer_config = [16, 16, 32, 32, 64, 64])
Create a CNN for a given set of configuration parameters.
# Arguments
- `width`: the input image width
- `height`: the input image height
- `inchannels`: the number of channels in the input image
- `nclasses`: the number of output classes
- `layer_config`: a vector of the number of filters per each conv layer
"""
function make_model(width, height, inchannels, nclasses;
layer_config = [16, 16, 32, 32, 64, 64])
# construct a vector of conv layers programmatically
conv_layers = [Conv((3, 3), inchannels => layer_config[1])]
for (infilters, outfilters) in zip(layer_config, layer_config[2:end])
push!(conv_layers, Conv((3, 3), infilters => outfilters))
end
# compute the output dimensions for the conv layers
# use padbatch=true to set the batch dimension to 1
conv_outsize = Flux.outputsize(conv_layers, (width, height, nchannels); padbatch=true)
# the input dimension to Dense is programatically calculated from
# width, height, and nchannels
return Chain(conv_layers..., Dense(prod(conv_outsize), nclasses))
end
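As a rough usage sketch (the image size and class count below are arbitrary, and the output assumes the default layer_config), the resulting model's output size can itself be checked with outputsize:
julia> model = make_model(32, 32, 3, 10);

julia> Flux.outputsize(model, (32, 32, 3); padbatch=true)
(10, 1)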
Flux.outputsize — Function
outputsize(m, x_size, y_size, ...; padbatch=false)
For a model or layer m accepting multiple arrays as input, this returns size(m((x, y, ...))) given x_size = size(x), etc.
Examples
julia> x, y = rand(Float32, 5, 64), rand(Float32, 7, 64);
julia> par = Parallel(vcat, Dense(5, 9), Dense(7, 11));
julia> Flux.outputsize(par, (5, 64), (7, 64))
(20, 64)
julia> m = Chain(par, Dense(20, 13), softmax);
julia> Flux.outputsize(m, (5,), (7,); padbatch=true)
(13, 1)
julia> par(x, y) == par((x, y)) == Chain(par, identity)((x, y))
true
Notice that Chain only accepts multiple arrays as a tuple, while Parallel also accepts them as multiple arguments; outputsize always supplies the tuple.
Model Abstraction
Flux.modules — Function
modules(m)
Return an iterator over non-leaf objects that can be reached by recursing m over the children given by functor.
Useful for applying a function (e.g. a regularizer) over specific modules or subsets of the parameters (e.g. the weights but not the biases).
Examples
julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu))
Chain(Dense(784, 64), BatchNorm(64, relu))
julia> m2 = Chain(m1, Dense(64, 10))
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
julia> Flux.modules(m2)
5-element Array{Any,1}:
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
Chain(Dense(784, 64), BatchNorm(64, relu))
Dense(784, 64)
BatchNorm(64, relu)
Dense(64, 10)
julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)
Flux.destructure — Function
destructure(m)
Flatten a model's parameters into a single weight vector.
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
julia> θ, re = destructure(m);
julia> θ
67-element Array{Float32,1}:
-0.1407104
...
The second return value re allows you to reconstruct the original network after making modifications to the weight vector (for example, with a hypernetwork).
julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Flux.nfan — Function
nfan(n_out, n_in=1) -> Tuple
nfan(dims...)
nfan(dims::Tuple)
For a layer characterized by dimensions dims, return a tuple (fan_in, fan_out), where fan_in is the number of input neurons connected to an output one, and fan_out is the number of output neurons connected to an input one.
This function is mainly used by weight initializers, e.g., kaiming_normal.
Examples
julia> layer = Dense(10, 20)
Dense(10, 20)
julia> Flux.nfan(size(layer.W))
(10, 20)
julia> layer = Conv((3, 3), 2=>10)
Conv((3, 3), 2=>10)
julia> Flux.nfan(size(layer.weight))
(18, 90)
Callback Helpers
Flux.throttle — Function
throttle(f, timeout; leading=true, trailing=false)
Return a function that, when invoked, will only be triggered at most once during timeout seconds.
Normally, the throttled function will run as much as it can, without ever going more than once per timeout duration; but if you'd like to disable the execution on the leading edge, pass leading=false. To enable execution on the trailing edge, pass trailing=true.
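The docstring above has no example, so here is a minimal sketch: a logging callback for use during training, fired at most once every 10 seconds no matter how often the training loop invokes it (the log message is just a placeholder):
evalcb = Flux.throttle(() -> @info("still training..."), 10)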
Flux.Optimise.stop — Function
stop()
Call Flux.stop() in a callback to indicate when a callback condition is met. This will trigger the train loop to stop and exit.
Examples
cb = function ()
accuracy() > 0.9 && Flux.stop()
end
Flux.Optimise.skip — Function
skip()
Call Flux.skip() in a callback to indicate when a callback condition is met. This will trigger the train loop to skip the current data point and not update with the calculated gradient.
Examples
cb = function ()
loss() > 1e7 && Flux.skip()
end
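Both helpers are meant to be called from the cb function passed to the training loop. A minimal sketch, assuming loss, ps, data, and opt are already defined, and combining the callback with throttle from above:
cb = () -> loss() > 1e7 && Flux.skip()
Flux.train!(loss, ps, data, opt; cb = Flux.throttle(cb, 5))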