Utility Functions
Flux provides utility functions which can be used to initialize your layers or to regularly execute callback functions.
Layer Initialisation
Flux initialises convolutional layers and recurrent cells with glorot_uniform by default. Most layers accept a function as an init keyword, which replaces this default. For example:
julia> conv = Conv((3, 3), 3 => 2, relu; init=Flux.glorot_normal)
Conv((3, 3), 3 => 2, relu) # 56 parameters
julia> conv.bias
2-element Vector{Float32}:
0.0
0.0

Note that init creates the weight array, but not the bias vector.
Many of the initialisation functions accept keywords such as gain, and a random number generator. To make it easy to pass these to layers, there are methods which return a function:
julia> Dense(4 => 5, tanh; init=Flux.glorot_uniform(gain=2))
Dense(4 => 5, tanh) # 25 parameters
julia> Dense(4 => 5, tanh; init=Flux.randn32(MersenneTwister(1)))
Dense(4 => 5, tanh) # 25 parameters

Flux.glorot_uniform — Function
glorot_uniform([rng = default_rng_value()], size...; gain = 1) -> Array
glorot_uniform([rng]; kw...) -> Function

Return an Array{Float32} of the given size containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = gain * sqrt(6 / (fan_in + fan_out)).
This method is described in [1] and also known as Xavier initialization.
Examples
julia> Flux.glorot_uniform(3, 4) |> summary
"3×4 Matrix{Float32}"
julia> round.(extrema(Flux.glorot_uniform(10, 100)), digits=3)
(-0.232f0, 0.234f0)
julia> round.(extrema(Flux.glorot_uniform(100, 10)), digits=3)
(-0.233f0, 0.233f0)
julia> round.(extrema(Flux.glorot_uniform(100, 100)), digits=3)
(-0.173f0, 0.173f0)
julia> Dense(3 => 2, tanh; init = Flux.glorot_uniform(MersenneTwister(1)))
Dense(3 => 2, tanh) # 8 parameters
julia> ans.bias
2-element Vector{Float32}:
0.0
0.0

References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
Flux.glorot_normal — Function
glorot_normal([rng = default_rng_value()], size...; gain = 1) -> Array
glorot_normal([rng]; kw...) -> Function

Return an Array{Float32} of the given size containing random numbers drawn from a normal distribution with standard deviation gain * sqrt(2 / (fan_in + fan_out)), using nfan.
This method is described in [1] and also known as Xavier initialization.
Examples
julia> using Statistics
julia> round(std(Flux.glorot_normal(10, 1000)), digits=3)
0.044f0
julia> round(std(Flux.glorot_normal(1000, 10)), digits=3)
0.044f0
julia> round(std(Flux.glorot_normal(1000, 1000)), digits=3)
0.032f0
julia> Dense(10 => 1000, tanh; init = Flux.glorot_normal(gain=100))
Dense(10 => 1000, tanh) # 11_000 parameters
julia> round(std(ans.weight), sigdigits=3)
4.45f0

References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
Flux.kaiming_uniform — Function
kaiming_uniform([rng = default_rng_value()], size...; gain = √2) -> Array
kaiming_uniform([rng]; kw...) -> Function

Return an Array{Float32} of the given size containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = gain * sqrt(3/fan_in), using nfan.
This method is described in [1] and also known as He initialization.
Examples
julia> round.(extrema(Flux.kaiming_uniform(100, 10)), digits=3)
(-0.774f0, 0.774f0)
julia> round.(extrema(Flux.kaiming_uniform(10, 100)), digits=3)
(-0.245f0, 0.244f0)
julia> round.(extrema(Flux.kaiming_uniform(100, 100)), digits=3)
(-0.245f0, 0.245f0)

References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
Flux.kaiming_normal — Function
kaiming_normal([rng = default_rng_value()], size...; gain = √2) -> Array
kaiming_normal([rng]; kw...) -> Function

Return an Array{Float32} of the given size containing random numbers drawn from a normal distribution with standard deviation gain / sqrt(fan_in), using nfan.
This method is described in [1] and also known as He initialization.
Examples
julia> using Statistics
julia> round(std(Flux.kaiming_normal(10, 1000)), digits=3)
0.045f0
julia> round(std(Flux.kaiming_normal(1000, 10)), digits=3)
0.447f0
julia> round(std(Flux.kaiming_normal(1000, 1000)), digits=3)
0.045f0

References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
Flux.truncated_normal — Function
truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
truncated_normal([rng]; kw...) -> Function

Return an Array{Float32} of the given size where each element is drawn from a truncated normal distribution. The numbers are distributed like filter(x -> lo<=x<=hi, mean .+ std .* randn(100)).
The values are generated by sampling a Uniform(0, 1) (rand()) and then applying the inverse CDF of the truncated normal distribution. This method works best when lo ≤ mean ≤ hi.
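As a rough illustration of that inverse-CDF approach, the sketch below samples one value the same way (this is not Flux's exact implementation; the helper name truncnorm_sample and the use of SpecialFunctions.erf / erfinv are assumptions made for this sketch):

using SpecialFunctions: erf, erfinv

# Sample one value from a normal(mean, std) truncated to [lo, hi]
# by mapping a Uniform(0, 1) draw through the inverse normal CDF.
function truncnorm_sample(; mean = 0f0, std = 1f0, lo = -2f0, hi = 2f0)
    norm_cdf(x) = (1 + erf(x / √2f0)) / 2        # standard normal CDF
    a = norm_cdf((lo - mean) / std)              # CDF value at the lower bound
    b = norm_cdf((hi - mean) / std)              # CDF value at the upper bound
    u = a + (b - a) * rand(Float32)              # uniform sample restricted to [a, b]
    return mean + std * √2f0 * erfinv(2u - 1)    # invert the CDF to get the sample
end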
Examples
julia> using Statistics
julia> Flux.truncated_normal(3, 4) |> summary
"3×4 Matrix{Float32}"
julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3)
(-2.0f0, 2.0f0)
julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
1.0f0

Flux.orthogonal — Function
orthogonal([rng = default_rng_value()], size...; gain = 1) -> Array
orthogonal([rng]; kw...) -> Function

Return an Array{Float32} of the given size which is a (semi) orthogonal matrix, as described in [1].
Cannot construct a vector, i.e. length(size) == 1 is forbidden. For length(size) > 2, a prod(size[1:(end - 1)]) by size[end] orthogonal matrix is computed before reshaping it to the original dimensions.
Examples
julia> using LinearAlgebra

julia> W = Flux.orthogonal(5, 7);
julia> summary(W)
"5×7 Matrix{Float32}"
julia> W * W' ≈ I(5)
true
julia> W2 = Flux.orthogonal(7, 5);
julia> W2 * W2' ≈ I(7)
false
julia> W2' * W2 ≈ I(5)
true
julia> W3 = Flux.orthogonal(3, 3, 2, 4);
julia> transpose(reshape(W3, :, 4)) * reshape(W3, :, 4) ≈ I(4)
true

References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
Flux.sparse_init — Function
sparse_init([rng = default_rng_value()], rows, cols; sparsity, std = 0.01) -> Array
sparse_init([rng]; kw...) -> Function

Return a Matrix{Float32} of size rows, cols where each column contains a fixed fraction of zero elements given by sparsity. Non-zero elements are normally distributed with a mean of zero and standard deviation std.
This method is described in [1].
Examples
julia> count(iszero, Flux.sparse_init(10, 10, sparsity=1/5))
20
julia> sum(0 .== Flux.sparse_init(10, 11, sparsity=0.9), dims=1)
1×11 Matrix{Int64}:
9 9 9 9 9 9 9 9 9 9 9
julia> Dense(3 => 10, tanh; init=Flux.sparse_init(sparsity=0.5))
Dense(3 => 10, tanh) # 40 parameters
julia> count(iszero, ans.weight, dims=1)
1×3 Matrix{Int64}:
5 5 5

References
[1] Martens, J, "Deep learning via Hessian-free optimization" Proceedings of the 27th International Conference on International Conference on Machine Learning. 2010.
Flux.identity_init — Function
identity_init(size...; gain=1, shift=0) -> Array
identity_init(; kw...) -> Function

Return an Array{Float32} of the given size which yields an identity mapping when used as parameters in most Flux layers. Use gain to scale the identity by a constant.
Often useful in the context of transfer learning, i.e when one wants to add more capacity to a model but start from the same mapping.
Has the following behaviour:
- 1D: A Vector of zeros (useful for an identity bias)
- 2D: An identity matrix (useful for an identity matrix multiplication)
- More than 2D: A dense block array of center tap spatial filters (useful for an identity convolution)

Some caveats:
- Not all layers will be identity mapping when used with this init. Exceptions include recurrent layers and normalization layers.
- Layers must have input_size == output_size for identity mapping to be possible. When this is not the case, extra dimensions of the array are padded with zeros.
- For convolutional layers, in addition to the above, the kernel sizes must also be odd and padding must be applied so that output feature maps have the same size as input feature maps, e.g. by using SamePad.
Use keyword shift (integer or tuple) to apply circular shift to the output, equivalent to Base.circshift(identity_init(size...), shift).
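For instance, that equivalence can be checked directly (a quick sanity check, relying only on the circshift relation stated above):

julia> Flux.identity_init(3, 3; shift=1) == circshift(Flux.identity_init(3, 3), 1)
true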
For consistency with other initialisers, it accepts rng::AbstractRNG as an optional first argument. But this is ignored, since the result is not random.
Examples
julia> Flux.identity_init(3,5)
3×5 Matrix{Float32}:
1.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0
0.0 0.0 1.0 0.0 0.0
julia> Dense(5 => 3, relu, init=Flux.identity_init)([1,-2,3,-4,5])
3-element Vector{Float32}:
1.0
0.0
3.0
julia> Flux.identity_init(3,3,2; gain=100)
3×3×2 Array{Float32, 3}:
[:, :, 1] =
0.0 0.0 0.0
100.0 0.0 0.0
0.0 0.0 0.0
[:, :, 2] =
0.0 0.0 0.0
0.0 100.0 0.0
0.0 0.0 0.0
julia> x4 = cat([1 2 3; 4 5 6; 7 8 9]; dims=4);
julia> Conv((2,2), 1 => 1, init=Flux.identity_init(gain=10), pad=SamePad())(x4)
3×3×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
10.0 20.0 30.0
40.0 50.0 60.0
70.0 80.0 90.0

Flux.ones32 — Function
ones32(size...) = ones(Float32, size...)

Return an Array{Float32} of the given size filled with 1s.

Flux.zeros32 — Function
zeros32(size...) = zeros(Float32, size...)

Return an Array{Float32} of the given size filled with 0s.

Flux.rand32 — Function
rand32([rng], size...)

Return an Array{Float32} of the given size, filled like rand. When the size is not provided, rand32(rng::AbstractRNG) returns a function.

Flux.randn32 — Function
randn32([rng], size...)

Return an Array{Float32} of the given size, filled like randn. When the size is not provided, randn32(rng::AbstractRNG) returns a function.
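A brief illustration of both forms (the summary strings are simply what one expects for Float32 arrays of those sizes):

julia> Flux.rand32(3, 4) |> summary
"3×4 Matrix{Float32}"

julia> Flux.randn32(MersenneTwister(1), 3, 4) |> summary
"3×4 Matrix{Float32}"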
Flux.rng_from_array — Function
rng_from_array([x])

Create an instance of the RNG most appropriate for x. The current defaults are:
- x isa CuArray: CUDA.default_rng()
- x isa AbstractArray, or no x provided:
  - Julia version is < 1.7: Random.GLOBAL_RNG
  - Julia version is >= 1.7: Random.default_rng()
Changing the type of model parameters
The default eltype for models is Float32 since models are often trained/run on GPUs. The eltype of model m can be changed to Float64 by f64(m):
Flux.f64 — Function
f64(m)

Converts the eltype of the model's parameters to Float64. Recurses into structs marked with @functor.

Flux.f32 — Function
f32(m)

Converts the eltype of the model's parameters to Float32 (which is Flux's default). Recurses into structs marked with @functor.
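For example, f64 can be used like this (a small sketch; the weight values themselves depend on the random initialisation):

julia> m = Chain(Dense(2 => 3, tanh));

julia> eltype(m[1].weight)
Float32

julia> eltype(Flux.f64(m)[1].weight)
Float64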
Model Building
Flux provides some utility functions to help you generate models in an automated fashion.
Flux.outputsize enables you to calculate the output sizes of layers like Conv when applied to input samples of a given size. This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. outputsize(f, inputsize) works for all layers (including custom layers) out of the box. By default, inputsize is expected to include the batch dimension; if you leave the batch dimension out, pass padbatch=true and a batch size of one is assumed: outputsize(f, inputsize; padbatch=true).
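For instance, one might check the size a single Conv layer produces for a 32×32 RGB input without building the full array (a sketch assuming the default stride 1 and no padding, so each spatial dimension shrinks by 2):

julia> Flux.outputsize(Conv((3, 3), 3 => 16), (32, 32, 3); padbatch=true)
(30, 30, 16, 1)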
Using this utility function lets you automate model building for various inputs like so:
"""
make_model(width, height, inchannels, nclasses;
layer_config = [16, 16, 32, 32, 64, 64])
Create a CNN for a given set of configuration parameters.
# Arguments
- `width`: the input image width
- `height`: the input image height
- `inchannels`: the number of channels in the input image
- `nclasses`: the number of output classes
- `layer_config`: a vector of the number of filters for each conv layer
"""
function make_model(width, height, inchannels, nclasses;
                    layer_config = [16, 16, 32, 32, 64, 64])
    # construct a vector of conv layers programmatically
    conv_layers = [Conv((3, 3), inchannels => layer_config[1])]
    for (infilters, outfilters) in zip(layer_config, layer_config[2:end])
        push!(conv_layers, Conv((3, 3), infilters => outfilters))
    end

    # compute the output dimensions of the conv layers
    # use padbatch=true to set the batch dimension to 1
    conv_outsize = Flux.outputsize(Chain(conv_layers...), (width, height, inchannels); padbatch=true)

    # the input dimension to Dense is programmatically calculated from
    # width, height, and inchannels
    return Chain(conv_layers..., Dense(prod(conv_outsize) => nclasses))
end

Flux.outputsize — Function
outputsize(m, x_size, y_size, ...; padbatch=false)

For model or layer m accepting multiple arrays as input, this returns size(m((x, y, ...))) given x_size = size(x), etc.
Examples
julia> x, y = rand(Float32, 5, 64), rand(Float32, 7, 64);
julia> par = Parallel(vcat, Dense(5, 9), Dense(7, 11));
julia> Flux.outputsize(par, (5, 64), (7, 64))
(20, 64)
julia> m = Chain(par, Dense(20, 13), softmax);
julia> Flux.outputsize(m, (5,), (7,); padbatch=true)
(13, 1)
julia> par(x, y) == par((x, y)) == Chain(par, identity)((x, y))
true

Notice that Chain only accepts multiple arrays as a tuple, while Parallel also accepts them as multiple arguments; outputsize always supplies the tuple.
Model Abstraction
Flux.modules — Function
modules(m)

Return an iterator over non-leaf objects that can be reached by recursing m over the children given by functor.
Useful for applying a function (e.g. a regularizer) over specific modules or subsets of the parameters (e.g. the weights but not the biases).
Examples
julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu));
julia> m2 = Chain(m1, Dense(64, 10))
Chain(
Chain(
Dense(784 => 64), # 50_240 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
),
Dense(64 => 10), # 650 parameters
) # Total: 6 trainable arrays, 51_018 parameters,
# plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.
julia> Flux.modules(m2)
7-element Vector{Any}:
Chain(Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10)) # 51_018 parameters, plus 128 non-trainable
(Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))
Chain(Dense(784 => 64), BatchNorm(64, relu)) # 50_368 parameters, plus 128 non-trainable
(Dense(784 => 64), BatchNorm(64, relu))
Dense(784 => 64) # 50_240 parameters
BatchNorm(64, relu) # 128 parameters, plus 128 non-trainable
Dense(64 => 10) # 650 parameters
julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)
julia> L2(m2) isa Float32
true

Flux.nfan — Function
nfan(n_out, n_in=1) -> Tuple
nfan(dims...)
nfan(dims::Tuple)

For a layer characterized by dimensions dims, return a tuple (fan_in, fan_out), where fan_in is the number of input neurons connected to an output one, and fan_out is the number of output neurons connected to an input one.
This function is mainly used by weight initializers, e.g., kaiming_normal.
Examples
julia> layer = Dense(10, 20);
julia> Flux.nfan(size(layer.weight))
(10, 20)
julia> layer = Conv((3, 3), 2=>10);
julia> Flux.nfan(size(layer.weight))
(18, 90)

Callback Helpers
Flux.throttle — Function
throttle(f, timeout; leading=true, trailing=false)

Return a function that, when invoked, will only be triggered at most once during timeout seconds.
Normally, the throttled function will run as much as it can, without ever going more than once per wait duration; but if you'd like to disable the execution on the leading edge, pass leading=false. To enable execution on the trailing edge, pass trailing=true.
Examples
julia> a = Flux.throttle(() -> println("Flux"), 2);
julia> for i = 1:4 # a called in alternate iterations
a()
sleep(1)
end
Flux
Flux

Flux.Optimise.stop — Function
stop()

Call Flux.stop() in a callback to indicate when a callback condition is met. This will trigger the train loop to stop and exit.
Flux.stop() will be removed from Flux 0.14. It should be replaced with break in an ordinary for loop.
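For instance, the replacement might look like the following sketch, where loss(x, y), model, opt, data and accuracy are placeholders for your own objects:

for epoch in 1:100
    Flux.train!(loss, Flux.params(model), data, opt)
    accuracy() > 0.9 && break   # plays the role of Flux.stop() in a callback
end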
Examples
cb = function ()
accuracy() > 0.9 && Flux.stop()
end

Flux.Optimise.skip — Function
skip()

Call Flux.skip() in a callback to indicate when a callback condition is met. This will trigger the train loop to skip the current data point and not update with the calculated gradient.
Flux.skip() will be removed from Flux 0.14.
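In an explicit training loop, the same effect is a continue statement, as in this sketch (loss, model, opt and data are placeholders for your own objects):

for (x, y) in data
    loss(x, y) > 1e7 && continue      # skip this data point, like Flux.skip()
    gs = Flux.gradient(() -> loss(x, y), Flux.params(model))
    Flux.update!(opt, Flux.params(model), gs)
end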
Examples
cb = function ()
loss() > 1e7 && Flux.skip()
end

Patience Helpers
Flux provides utilities for controlling your training procedure according to some monitored condition and a maximum patience. For example, you can use early_stopping to stop training when the model is converging or deteriorating, or you can use plateau to check if the model is stagnating.
For example, below we create a pseudo-loss function that decreases, bottoms out, and then increases. The early stopping trigger will break the loop before the loss increases too much.
# create a pseudo-loss that decreases for 4 calls, then starts increasing
# we call this like loss()
loss = let t = 0
() -> begin
t += 1
(t - 4) ^ 2
end
end
# create an early stopping trigger
# returns true when the loss increases for two consecutive steps
es = early_stopping(loss, 2; init_score = 9)
# this will stop at the 6th (4 decreasing + 2 increasing calls) epoch
@epochs 10 begin
es() && break
end

The keyword argument distance of early_stopping is a function of the form distance(best_score, score). By default distance is -, which implies that the monitored metric f is expected to be decreasing and minimized. If you use some increasing metric (e.g. accuracy), you can customize the distance function: (best_score, score) -> score - best_score.
# create a pseudo-accuracy that increases by 0.01 each time from 0 to 1
# we call this like acc()
acc = let v = 0
() -> v = min(1, v + 0.01)
end
# create an early stopping trigger for accuracy
es = early_stopping(acc, 3; distance = (best_score, score) -> score - best_score)
# this will iterate until the 10th epoch
@epochs 10 begin
es() && break
end

early_stopping and plateau are both built on top of patience. You can use patience to build your own triggers that use a patient counter. For example, if you want to trigger when the loss is below a threshold for several consecutive iterations:
threshold(f, thresh, delay) = patience(delay) do
f() < thresh
end

Both predicate in patience and f in early_stopping / plateau can accept extra arguments. You can pass such extra arguments to predicate or f through the returned function:
trigger = patience((a; b) -> a > b, 3)
# this will iterate until the 10th epoch
@epochs 10 begin
trigger(1; b = 2) && break
end
# this will stop at the 3rd epoch
@epochs 10 begin
trigger(3; b = 2) && break
end

Flux.patience — Function
patience(predicate, wait)

Return a function that internally counts by one when predicate(...) == true, otherwise the count is reset to zero. If the count is greater than or equal to wait, the function returns true, otherwise it returns false.
Examples
julia> loss() = rand();
julia> trigger = Flux.patience(() -> loss() < 1, 3);
julia> for i in 1:10
@info "Epoch $i"
trigger() && break
end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3

Flux.early_stopping — Function
early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0)

Return a function that internally counts by one when distance(best_score, f(...)) <= min_dist, where best_score is the last seen best value of f(...). If the count is greater than or equal to delay, the function returns true, otherwise it returns false. The count is reset when distance(best_score, f(...)) > min_dist.
Examples
julia> loss = let l = 0
() -> l += 1
end; # pseudo loss function that returns increasing values
julia> es = Flux.early_stopping(loss, 3);
julia> for i in 1:10
@info "Epoch $i"
es() && break
end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3

Flux.plateau — Function
plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)

Return a function that internally counts by one when abs(distance(last_score, f(...))) <= min_dist, where last_score holds the last value of f(...). If the count is greater than or equal to width, the function returns true, otherwise it returns false. The count is reset when abs(distance(last_score, f(...))) > min_dist.
Examples
julia> f = let v = 10
() -> v = v / abs(v) - v
end; # -9, 8, -7, 6, ...
julia> trigger = Flux.plateau(f, 3; init_score=10, min_dist=18);
julia> for i in 1:10
@info "Epoch $i"
trigger() && break
end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
[ Info: Epoch 4