Utility Functions
Flux contains some utility functions for working with data; these functions help create inputs for your models or batch your dataset. Other functions can be used to initialize your layers or to regularly execute callback functions.
Working with Data
Flux.unsqueeze — Function
unsqueeze(xs, dim)
Return xs reshaped into an array one dimensionality higher than xs, where dim indicates in which dimension xs is extended.
Examples
julia> Flux.unsqueeze([1 2; 3 4], 2)
2×1×2 Array{Int64, 3}:
[:, :, 1] =
1
3
[:, :, 2] =
2
4
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Vector{Vector{Int64}}:
[1, 2]
[3, 4]
[5, 6]
julia> Flux.unsqueeze(xs, 1)
1×3 Matrix{Vector{Int64}}:
[1, 2] [3, 4] [5, 6]
unsqueeze(dim)
Returns a function which, acting on an array, inserts a dimension of size 1 at dim.
Examples
julia> rand(21, 22, 23) |> Flux.unsqueeze(2) |> size
(21, 1, 22, 23)
julia> m = Chain(Flux.unsqueeze(3), Flux.unsqueeze(4), Conv((3,3), 1=>7, pad=SamePad()));
julia> rand(Float32, 10, 10) |> m |> size
(10, 10, 7, 1)
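The reshaping described above can be sketched in plain Julia. This is a minimal illustration, not Flux's implementation; the name my_unsqueeze is hypothetical.

```julia
# Hypothetical helper (not part of Flux): insert a singleton dimension
# at position `dim` by reshaping to a size tuple with an extra 1.
function my_unsqueeze(xs::AbstractArray, dim::Integer)
    sz = size(xs)
    newsize = (sz[1:dim-1]..., 1, sz[dim:end]...)
    return reshape(xs, newsize)
end

a = my_unsqueeze([1 2; 3 4], 2)
size(a)                                  # (2, 1, 2)
size(my_unsqueeze(rand(21, 22, 23), 2))  # (21, 1, 22, 23)
```

Because reshape does not copy, the result shares memory with the input in this sketch.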
Flux.stack — Function
stack(xs, dim)
Concatenate the given Array of Arrays xs into a single Array along the given dimension dim.
Examples
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Vector{Vector{Int64}}:
[1, 2]
[3, 4]
[5, 6]
julia> Flux.stack(xs, 1)
3×2 Matrix{Int64}:
1 2
3 4
5 6
julia> cat(xs, dims=1)
3-element Vector{Vector{Int64}}:
[1, 2]
[3, 4]
[5, 6]
Flux.unstack — Function
unstack(xs, dim)
Unroll the given xs into an Array of Arrays along the given dimension dim.
Examples
julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
4-element Vector{Vector{Int64}}:
[1, 2]
[3, 4]
[5, 6]
[7, 8]
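To make the relationship between stacking and unstacking concrete, here is a rough plain-Julia sketch of both operations. The names my_stack and my_unstack are hypothetical, and the code only approximates the documented behaviour.

```julia
# Hypothetical helpers (not part of Flux).
function my_stack(xs, dim::Integer)
    # give every array a singleton dimension at `dim`, then concatenate there
    expand(x) = reshape(x, (size(x)[1:dim-1]..., 1, size(x)[dim:end]...))
    return cat(map(expand, xs)...; dims = dim)
end

# take one slice per index along `dim` and copy it into its own array
my_unstack(xs, dim::Integer) = [copy(selectdim(xs, dim, i)) for i in axes(xs, dim)]

xs = [[1, 2], [3, 4], [5, 6]]
m = my_stack(xs, 1)            # 3×2 matrix with rows [1 2], [3 4], [5 6]
my_unstack(m, 1) == xs         # true: unstacking reverses stacking
```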
Flux.chunk — Function
chunk(xs, n)
Split xs into n parts.
Examples
julia> Flux.chunk(1:10, 3)
3-element Vector{UnitRange{Int64}}:
1:4
5:8
9:10
julia> Flux.chunk(collect(1:10), 3)
3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}:
[1, 2, 3, 4]
[5, 6, 7, 8]
[9, 10]
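The outputs above are consistent with partitioning xs into pieces of length ceil(length(xs) / n). A minimal sketch under that assumption follows; my_chunk is a hypothetical name.

```julia
# Hypothetical helper (not part of Flux): split `xs` into consecutive
# pieces of length ceil(length(xs) / n); the last piece may be shorter.
my_chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs) / n)))

parts = my_chunk(1:10, 3)   # [1:4, 5:8, 9:10]
```

Under this assumption the number of pieces can be smaller than n when the lengths do not divide evenly, for example my_chunk(1:4, 3) yields two pieces.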
Flux.frequencies — Function
frequencies(xs)
Count the number of times that each element of xs appears.
Examples
julia> Flux.frequencies(['a','b','b'])
Dict{Char, Int64} with 2 entries:
'a' => 1
'b' => 2
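A counting loop like the following would produce the same kind of dictionary. It is only a sketch; my_frequencies is a hypothetical name.

```julia
# Hypothetical helper (not part of Flux): count occurrences with a Dict.
function my_frequencies(xs)
    counts = Dict{eltype(xs), Int}()
    for x in xs
        counts[x] = get(counts, x, 0) + 1
    end
    return counts
end

freqs = my_frequencies(['a', 'b', 'b'])
```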
Flux.batch — Function
batch(xs)
Batch the arrays in xs into a single array.
Examples
julia> Flux.batch([[1,2,3],[4,5,6]])
3×2 Matrix{Int64}:
1 4
2 5
3 6
Flux.batchseq — Function
batchseq(seqs, pad)
Take a list of N sequences, and turn them into a single sequence where each item is a batch of N. Short sequences will be padded by pad.
Examples
julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
3-element Vector{Vector{Int64}}:
[1, 4]
[2, 5]
[3, 0]
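The pad-then-regroup idea can be sketched in a few lines of plain Julia. The helper name my_batchseq is hypothetical, and this is not the library's code.

```julia
# Hypothetical helper (not part of Flux): pad every sequence to the
# longest length, then collect the t-th items of all sequences together.
function my_batchseq(seqs, pad)
    n = maximum(length, seqs)
    padded = [vcat(s, fill(pad, n - length(s))) for s in seqs]
    return [[p[t] for p in padded] for t in 1:n]
end

batches = my_batchseq([[1, 2, 3], [4, 5]], 0)   # [[1, 4], [2, 5], [3, 0]]
```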
Base.rpad — Method
Return the given sequence padded with p up to a length of n. A sequence that already has n or more elements is returned unchanged.
Examples
julia> rpad([1, 2], 4, 0)
4-element Vector{Int64}:
1
2
0
0
julia> rpad([1, 2, 3], 2, 0)
3-element Vector{Int64}:
1
2
3
Layer Initialization
These are primarily useful if you are planning to write your own layers. Flux initializes convolutional layers and recurrent cells with glorot_uniform by default. To change the default on an applicable layer, pass the desired function with the init keyword. For example:
julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
Conv((3, 3), 1=>8, relu)
Flux.glorot_uniform — Function
glorot_uniform([rng=GLOBAL_RNG], dims...)
Return an Array of size dims containing random variables taken from a uniform distribution in the interval [-x, x], where x = sqrt(6 / (fan_in + fan_out)).
This method is described in [1] and also known as Xavier initialization.
Examples
julia> Flux.glorot_uniform(2, 3)
2×3 Matrix{Float32}:
0.601094 -0.57414 -0.814925
0.900868 0.805994 0.057514
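As a rough illustration of the formula, the following sketch draws a 2-D weight of size (n_out, n_in) from the stated interval. It assumes fan_in = n_in and fan_out = n_out for a matrix; my_glorot_uniform is a hypothetical name.

```julia
using Random

# Hypothetical helper (not part of Flux) for a matrix of size (n_out, n_in).
function my_glorot_uniform(rng::AbstractRNG, n_out::Integer, n_in::Integer)
    x = sqrt(6f0 / (n_in + n_out))            # bound of the interval [-x, x]
    # rand gives values in [0, 1); shift and scale them into [-x, x)
    return (rand(rng, Float32, n_out, n_in) .- 0.5f0) .* 2x
end

W = my_glorot_uniform(MersenneTwister(0), 2, 3)
maximum(abs, W) <= sqrt(6f0 / 5)              # true: all entries lie within the bound
```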
See also
- Glorot initialization using normal distribution: glorot_normal
- Kaiming initialization using normal distribution: kaiming_normal
- Kaiming initialization using uniform distribution: kaiming_uniform
- sparse initialization: sparse_init
- calculation of fan_in and fan_out: nfan
References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
Flux.glorot_normal — Function
glorot_normal([rng=GLOBAL_RNG], dims...)
Return an Array of size dims containing random variables taken from a normal distribution with mean 0 and standard deviation sqrt(2 / (fan_in + fan_out)).
This method is described in [1] and also known as Xavier initialization.
Examples
julia> Flux.glorot_normal(3, 2)
3×2 Matrix{Float32}:
0.429505 -0.0852891
0.523935 0.371009
-0.223261 0.188052
See also
- Glorot initialization using uniform distribution: glorot_uniform
- Kaiming initialization using normal distribution: kaiming_normal
- Kaiming initialization using uniform distribution: kaiming_uniform
- sparse initialization: sparse_init
- calculation of fan_in and fan_out: nfan
References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
Flux.kaiming_uniform — Function
kaiming_uniform([rng=GLOBAL_RNG], dims...; gain = √2)
Return an Array of size dims containing random variables taken from a uniform distribution in the interval [-x, x], where x = gain * sqrt(3/fan_in).
This method is described in [1] and also known as He initialization.
Examples
julia> Flux.kaiming_uniform(3, 2)
3×2 Matrix{Float32}:
0.950413 1.27439
1.4244 -1.28851
-0.907795 0.0909376
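A uniform distribution on [-x, x] has standard deviation x / sqrt(3), so the bound x = gain * sqrt(3/fan_in) corresponds to a standard deviation of gain / sqrt(fan_in). The sketch below checks this numerically; the value of fan_in and the sample count are arbitrary choices for illustration.

```julia
using Random, Statistics

fan_in, gain = 256, sqrt(2)          # illustrative values, not library defaults for any layer
x = gain * sqrt(3 / fan_in)          # bound of the uniform interval
σ = gain / sqrt(fan_in)              # implied standard deviation, x / sqrt(3)

rng = MersenneTwister(1)
u = (rand(rng, 100_000) .- 0.5) .* 2x   # samples from the uniform distribution on [-x, x)
std(u), σ                               # the two values should agree closely
```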
See also
- Kaiming initialization using normal distribution: kaiming_normal
- Glorot initialization using normal distribution: glorot_normal
- Glorot initialization using uniform distribution: glorot_uniform
- sparse initialization: sparse_init
- calculation of fan_in and fan_out: nfan
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
Flux.kaiming_normal — Function
kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2)
Return an Array of size dims containing random variables taken from a normal distribution with mean 0 and standard deviation gain / sqrt(fan_in).
This method is described in [1] and also known as He initialization.
Examples
julia> Flux.kaiming_normal(3, 2)
3×2 Matrix{Float32}:
0.679107 -0.134854
0.828413 0.586617
-0.353007 0.297336
See also
- Kaiming initialization using uniform distribution: kaiming_uniform
- Glorot initialization using normal distribution: glorot_normal
- Glorot initialization using uniform distribution: glorot_uniform
- sparse initialization: sparse_init
- calculation of fan_in and fan_out: nfan
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
Flux.orthogonal — Function
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
Return an Array of size dims which is a (semi) orthogonal matrix, as described in [1].
The input must have at least 2 dimensions. For length(dims) > 2, a prod(dims[1:(end - 1)]) by dims[end] orthogonal matrix is computed before reshaping it to the original dimensions.
Examples
julia> W = Flux.orthogonal(5, 7);
julia> summary(W)
"5×7 Matrix{Float32}"
julia> W * W' ≈ I(5)
true
julia> W2 = Flux.orthogonal(7, 5);
julia> W2 * W2' ≈ I(7)
false
julia> W2' * W2 ≈ I(5)
true
julia> W3 = Flux.orthogonal(3, 3, 2, 4);
julia> transpose(reshape(W3, :, 4)) * reshape(W3, :, 4) ≈ I(4)
true
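One common way to obtain a (semi) orthogonal matrix is a QR decomposition of a Gaussian random matrix. The sketch below illustrates that technique for the 2-D case only; it is an assumption about how such an initializer can be built, and my_orthogonal is a hypothetical name.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not part of Flux): semi-orthogonal matrix via QR.
function my_orthogonal(rng::AbstractRNG, rows::Integer, cols::Integer; gain = 1)
    if rows < cols
        # build the tall case and transpose, so that the rows are orthonormal
        return permutedims(my_orthogonal(rng, cols, rows; gain = gain))
    end
    A = randn(rng, Float32, rows, cols)
    Q, R = qr(A)
    # Matrix(Q) is the thin factor (rows × cols); fixing the signs with
    # diag(R) makes the result independent of the QR sign convention
    return Float32(gain) .* (Matrix(Q) * Diagonal(sign.(diag(R))))
end

tall = my_orthogonal(MersenneTwister(0), 7, 5)
wide = my_orthogonal(MersenneTwister(0), 5, 7)
tall' * tall ≈ I(5)    # true: orthonormal columns
wide * wide' ≈ I(5)    # true: orthonormal rows
```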
See also
- Kaiming initialization using normal distribution: kaiming_normal
- Kaiming initialization using uniform distribution: kaiming_uniform
- Glorot initialization using normal distribution: glorot_normal
- Glorot initialization using uniform distribution: glorot_uniform
- sparse initialization: sparse_init
References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
Flux.sparse_init — Function
sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)
Return an Array of size dims where each column contains a fixed fraction of zero elements given by sparsity. Non-zero elements are normally distributed with a mean of zero and standard deviation std.
This method is described in [1].
Examples
julia> Flux.sparse_init(3, 2, sparsity=0.1)
3×2 Matrix{Float32}:
0.00828413 0.0
-0.00353007 0.00297336
0.0 0.00586617
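The per-column zeroing can be sketched as follows. The sketch assumes the number of zeros per column is ceil(sparsity * rows), which matches the one zero per column in the example above; my_sparse_init is a hypothetical name.

```julia
using Random

# Hypothetical helper (not part of Flux).
function my_sparse_init(rng::AbstractRNG, rows::Integer, cols::Integer;
                        sparsity, std = 0.01)
    W = randn(rng, Float32, rows, cols) .* Float32(std)
    nzero = ceil(Int, sparsity * rows)        # assumed rounding rule
    for j in 1:cols
        # zero out `nzero` randomly chosen rows in column j
        W[randperm(rng, rows)[1:nzero], j] .= 0
    end
    return W
end

S = my_sparse_init(MersenneTwister(0), 3, 2; sparsity = 0.1)
[count(iszero, col) for col in eachcol(S)]    # one zero in each column
```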
See also
- Kaiming initialization using normal distribution: kaiming_normal
- Kaiming initialization using uniform distribution: kaiming_uniform
- Glorot initialization using normal distribution: glorot_normal
- Glorot initialization using uniform distribution: glorot_uniform
References
[1] Martens, J. "Deep learning via Hessian-free optimization." Proceedings of the 27th International Conference on Machine Learning. 2010.
Changing the type of model parameters
Flux.f64 — Function
f64(m)
Convert the eltype of model's parameters to Float64.
Flux.f32 — Function
f32(m)
Convert the eltype of model's parameters to Float32.
The default eltype for models is Float32 since models are often trained/run on GPUs. The eltype of model m can be changed to Float64 by f64(m), or to Float32 by f32(m).
Model Building
Flux provides some utility functions to help you generate models in an automated fashion.
outputsize enables you to calculate the output sizes of layers like Conv when applied to input samples of a given size. This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. outputsize(f, inputsize) works for all layers (including custom layers) out of the box. By default, inputsize expects the batch dimension, but you can exclude the batch size with outputsize(f, inputsize; padbatch=true) (assuming it to be one).
Using this utility function lets you automate model building for various inputs like so:
"""
    make_model(width, height, inchannels, nclasses;
               layer_config = [16, 16, 32, 32, 64, 64])

Create a CNN for a given set of configuration parameters.

# Arguments
- `width`: the input image width
- `height`: the input image height
- `inchannels`: the number of channels in the input image
- `nclasses`: the number of output classes
- `layer_config`: a vector of the number of filters per each conv layer
"""
function make_model(width, height, inchannels, nclasses;
                    layer_config = [16, 16, 32, 32, 64, 64])
  # construct a vector of conv layers programmatically
  conv_layers = [Conv((3, 3), inchannels => layer_config[1])]
  for (infilters, outfilters) in zip(layer_config, layer_config[2:end])
    push!(conv_layers, Conv((3, 3), infilters => outfilters))
  end
  # compute the output dimensions for the conv layers
  # use padbatch=true to set the batch dimension to 1
  conv_outsize = Flux.outputsize(conv_layers, (width, height, inchannels); padbatch=true)
  # the input dimension to Dense is programmatically calculated from
  # width, height, and inchannels; the conv output is flattened before Dense
  return Chain(conv_layers..., Flux.flatten, Dense(prod(conv_outsize), nclasses))
end
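For comparison, this is roughly the arithmetic that outputsize saves you from doing by hand for a stack of 3×3 convolutions with stride 1 and no padding. The sketch uses the standard convolution size formula and does not call Flux; conv_out and conv_stack_outsize are hypothetical names.

```julia
# Standard output length of a convolution along one spatial dimension
# (dilation 1): floor((n + 2*pad - k) / stride) + 1.
conv_out(n; k = 3, pad = 0, stride = 1) = fld(n + 2pad - k, stride) + 1

# Hypothetical helper: propagate a (width, height) input through one
# 3×3 conv per entry of `layer_config`, with a batch dimension of 1.
function conv_stack_outsize(width, height, layer_config)
    w, h = width, height
    for _ in layer_config
        w, h = conv_out(w), conv_out(h)
    end
    return (w, h, last(layer_config), 1)
end

sz = conv_stack_outsize(28, 28, [16, 16, 32, 32, 64, 64])   # (16, 16, 64, 1)
prod(sz)                                                    # 16384 inputs to the Dense layer
```

Each unpadded 3×3 layer removes two pixels per spatial dimension, so six layers shrink 28 to 16.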
Flux.outputsize — Function
outputsize(m, x_size, y_size, ...; padbatch=false)
For model or layer m accepting multiple arrays as input, this returns size(m((x, y, ...))) given x_size = size(x), etc.
Examples
julia> x, y = rand(Float32, 5, 64), rand(Float32, 7, 64);
julia> par = Parallel(vcat, Dense(5, 9), Dense(7, 11));
julia> Flux.outputsize(par, (5, 64), (7, 64))
(20, 64)
julia> m = Chain(par, Dense(20, 13), softmax);
julia> Flux.outputsize(m, (5,), (7,); padbatch=true)
(13, 1)
julia> par(x, y) == par((x, y)) == Chain(par, identity)((x, y))
true
Notice that Chain only accepts multiple arrays as a tuple, while Parallel also accepts them as multiple arguments; outputsize always supplies the tuple.
Model Abstraction
Flux.modules — Function
modules(m)
Return an iterator over non-leaf objects that can be reached by recursing m over the children given by functor.
Useful for applying a function (e.g. a regularizer) over specific modules or subsets of the parameters (e.g. the weights but not the biases).
Examples
julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu))
Chain(Dense(784, 64), BatchNorm(64, relu))
julia> m2 = Chain(m1, Dense(64, 10))
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
julia> Flux.modules(m2)
5-element Vector{Any}:
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
Chain(Dense(784, 64), BatchNorm(64, relu))
Dense(784, 64)
BatchNorm(64, relu)
Dense(64, 10)
julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)
Flux.destructure — Function
destructure(m)
Flatten a model's parameters into a single weight vector.
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
julia> θ, re = destructure(m);
julia> θ
67-element Vector{Float32}:
-0.1407104
...
The second return value re allows you to reconstruct the original network after making modifications to the weight vector (for example, with a hypernetwork).
julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Flux.nfan — Function
nfan(n_out, n_in=1) -> Tuple
nfan(dims...)
nfan(dims::Tuple)
For a layer characterized by dimensions dims, return a tuple (fan_in, fan_out), where fan_in is the number of input neurons connected to an output one, and fan_out is the number of output neurons connected to an input one.
This function is mainly used by weight initializers, e.g., kaiming_normal.
Examples
julia> layer = Dense(10, 20)
Dense(10, 20)
julia> Flux.nfan(size(layer.weight))
(10, 20)
julia> layer = Conv((3, 3), 2=>10)
Conv((3, 3), 2=>10)
julia> Flux.nfan(size(layer.weight))
(18, 90)
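Both results above follow from a simple rule: for a matrix the fans are the two dimensions, and for a convolution kernel each fan is multiplied by the receptive field size. A sketch under that assumption, with the hypothetical name my_nfan:

```julia
# Hypothetical helpers (not part of Flux).
my_nfan(n_out, n_in = 1) = (n_in, n_out)            # vector or matrix case

function my_nfan(dims::Tuple)
    length(dims) <= 2 && return my_nfan(dims...)
    receptive = prod(dims[1:end-2])                 # e.g. 3 * 3 for a 3×3 kernel
    return (receptive * dims[end-1], receptive * dims[end])
end

my_nfan((20, 10))        # (10, 20), weight size of a 10 => 20 dense layer
my_nfan((3, 3, 2, 10))   # (18, 90), weight size of a 3×3 conv with 2 => 10 channels
```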
Callback Helpers
Flux.throttle — Function
throttle(f, timeout; leading=true, trailing=false)
Return a function that, when invoked, will only be triggered at most once during timeout seconds.
Normally, the throttled function will run as much as it can, without ever going more than once per timeout duration; but if you'd like to disable the execution on the leading edge, pass leading=false. To enable execution on the trailing edge, pass trailing=true.
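The leading-edge behaviour can be sketched with a closure that remembers when it last ran. This simplified version ignores the leading and trailing options; my_throttle is a hypothetical name.

```julia
# Hypothetical helper (not part of Flux): leading-edge throttling only.
function my_throttle(f, timeout::Real)
    last_run = -Inf
    return function (args...)
        t = time()
        if t - last_run >= timeout
            last_run = t
            return f(args...)
        end
        return nothing          # call suppressed: still inside the timeout window
    end
end

calls = Ref(0)
report = my_throttle(() -> calls[] += 1, 60)
for _ in 1:5
    report()                    # only the first call gets through
end
calls[]                         # 1
```

A typical use is wrapping an expensive evaluation callback so that it runs at most once every few seconds during training.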
Flux.Optimise.stop — Function
stop()
Call Flux.stop() in a callback to indicate when a callback condition is met. This will trigger the train loop to stop and exit.
Examples
cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
Flux.Optimise.skip — Function
skip()
Call Flux.skip() in a callback to indicate when a callback condition is met. This will trigger the train loop to skip the current data point and not update with the calculated gradient.
Examples
cb = function ()
  loss() > 1e7 && Flux.skip()
end
Patience Helpers
Flux provides utilities for controlling your training procedure according to some monitored condition and a maximum patience. For example, you can use early_stopping to stop training when the model is converging or deteriorating, or you can use plateau to check if the model is stagnating.
For example, below we create a pseudo-loss function that decreases, bottoms out, then increases. The early stopping trigger will break the loop before the loss increases too much.
# create a pseudo-loss that decreases for 4 calls, then starts increasing
# we call this like loss()
loss = let t = 0
  () -> begin
    t += 1
    (t - 4) ^ 2
  end
end
# create an early stopping trigger
# returns true when the loss increases for two consecutive steps
es = early_stopping(loss, 2; init_score = 9)
# this will stop at the 6th (4 decreasing + 2 increasing calls) epoch
@epochs 10 begin
  es() && break
end
The keyword argument distance of early_stopping is a function of the form distance(best_score, score). By default distance is -, which implies that the monitored metric f is expected to be decreasing and minimized. If you use some increasing metric (e.g. accuracy), you can customize the distance function: (best_score, score) -> score - best_score.
# create a pseudo-accuracy that increases by 0.01 each time from 0 to 1
# we call this like acc()
acc = let v = 0
  () -> v = min(1, v + 0.01)
end
# create an early stopping trigger for accuracy
es = early_stopping(acc, 3; distance = (best_score, score) -> score - best_score)
# this will iterate until the 10th epoch
@epochs 10 begin
  es() && break
end
early_stopping and plateau are both built on top of patience. You can use patience to build your own triggers that use a patient counter. For example, if you want to trigger when the loss is below a threshold for several consecutive iterations:
threshold(f, thresh, delay) = patience(delay) do
  f() < thresh
end
Both predicate in patience and f in early_stopping / plateau can accept extra arguments. You can pass such extra arguments to predicate or f through the returned function:
trigger = patience((a; b) -> a > b, 3)
# this will iterate until the 10th epoch
@epochs 10 begin
  trigger(1; b = 2) && break
end
# this will stop at the 3rd epoch
@epochs 10 begin
  trigger(3; b = 2) && break
end
Flux.patience — Function
patience(predicate, wait)
Return a function that internally counts by one when predicate(...) == true, otherwise the count is reset to zero. If the count is greater than or equal to wait, the function returns true, otherwise it returns false.
Examples
julia> loss() = rand();
julia> trigger = Flux.patience(() -> loss() < 1, 3);
julia> Flux.@epochs 10 begin
         trigger() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
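The counting rule described above fits in a small closure. The following is a sketch of that rule rather than the library's code; my_patience is a hypothetical name.

```julia
# Hypothetical helper (not part of Flux): count consecutive true results
# of `predicate` and fire once the count reaches `wait`.
function my_patience(predicate, wait::Integer)
    hits = 0
    return function (args...; kwargs...)
        hits = predicate(args...; kwargs...) ? hits + 1 : 0
        return hits >= wait
    end
end

trigger = my_patience(x -> x > 0, 2)
results = [trigger(x) for x in (1, -1, 1, 1)]   # [false, false, false, true]
```

The single negative input resets the counter, so two further positive inputs are needed before the trigger fires.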
Flux.early_stopping — Function
early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0)
Return a function that internally counts by one when distance(best_score, f(...)) <= min_dist, where best_score is the last seen best value of f(...). If the count is greater than or equal to delay, the function returns true, otherwise it returns false. The count is reset when distance(best_score, f(...)) > min_dist.
Examples
julia> loss = let l = 0
         () -> l += 1
       end; # pseudo loss function that returns increasing values
julia> es = Flux.early_stopping(loss, 3);
julia> Flux.@epochs 10 begin
         es() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
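To see how the documented rule plays out on a loss that first falls and then rises, here is a self-contained sketch that follows the description above literally. It may differ from the library in edge cases such as ties, and my_early_stopping is a hypothetical name.

```julia
# Hypothetical helper (not part of Flux), following the documented rule.
function my_early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0)
    best, stalled = init_score, 0
    return function ()
        score = f()
        if distance(best, score) <= min_dist
            stalled += 1            # no sufficient improvement this call
        else
            stalled = 0             # improvement: reset and remember the score
            best = score
        end
        return stalled >= delay
    end
end

# pseudo-loss with values 9, 4, 1, 0, 1, 4, ...
loss = let t = 0
    () -> (t += 1; (t - 4)^2)
end
es = my_early_stopping(loss, 2; init_score = 9)
first_true = findfirst(_ -> es(), 1:10)   # 6: two non-improving calls after the minimum
```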
Flux.plateau — Function
plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)
Return a function that internally counts by one when abs(distance(last_score, f(...))) <= min_dist, where last_score holds the last value of f(...). If the count is greater than or equal to width, the function returns true, otherwise it returns false. The count is reset when abs(distance(last_score, f(...))) > min_dist.
Examples
julia> f = let v = 10
         () -> v = v / abs(v) - v
       end; # -9, 8, -7, 6, ...
julia> trigger = Flux.plateau(f, 3; init_score=10, min_dist=18);
julia> Flux.@epochs 10 begin
         trigger() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
[ Info: Epoch 4