Utility Functions

Flux contains some utility functions for working with data; these help create inputs for your models or batch your dataset. Other functions can be used to initialize your layers or to execute callback functions at regular intervals.

Working with Data

Flux.unsqueeze (Function)
unsqueeze(xs, dim)

Return xs reshaped into an array one dimensionality higher than xs, where dim indicates in which dimension xs is extended.

See also flatten, stack.

Examples

julia> Flux.unsqueeze([1 2; 3 4], 2)
2×1×2 Array{Int64, 3}:
[:, :, 1] =
 1
 3

[:, :, 2] =
 2
 4

julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Vector{Vector{Int64}}:
 [1, 2]
 [3, 4]
 [5, 6]

julia> Flux.unsqueeze(xs, 1)
1×3 Matrix{Vector{Int64}}:
 [1, 2]  [3, 4]  [5, 6]
unsqueeze(dim)

Returns a function which, acting on an array, inserts a dimension of size 1 at dim.

Examples

julia> rand(21, 22, 23) |> Flux.unsqueeze(2) |> size
(21, 1, 22, 23)

julia> m = Chain(Flux.unsqueeze(3), Flux.unsqueeze(4), Conv((3,3), 1=>7, pad=SamePad()));

julia> rand(Float32, 10, 10) |> m |> size
(10, 10, 7, 1)
Flux.stack (Function)
stack(xs, dim)

Concatenate the given Array of Arrays xs into a single Array along the given dimension dim.

Examples

julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Vector{Vector{Int64}}:
 [1, 2]
 [3, 4]
 [5, 6]

julia> Flux.stack(xs, 1)
3×2 Matrix{Int64}:
 1  2
 3  4
 5  6

julia> cat(xs, dims=1)
3-element Vector{Vector{Int64}}:
 [1, 2]
 [3, 4]
 [5, 6]
Flux.unstack (Function)
unstack(xs, dim)

Unroll the given xs into an Array of Arrays along the given dimension dim.

Examples

julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
4-element Vector{Vector{Int64}}:
 [1, 2]
 [3, 4]
 [5, 6]
 [7, 8]
Flux.chunk (Function)
chunk(xs, n)

Split xs into n parts.

Examples

julia> Flux.chunk(1:10, 3)
3-element Vector{UnitRange{Int64}}:
 1:4
 5:8
 9:10

julia> Flux.chunk(collect(1:10), 3)
3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}:
 [1, 2, 3, 4]
 [5, 6, 7, 8]
 [9, 10]
Flux.frequencies (Function)
frequencies(xs)

Count the number of times that each element of xs appears.

Examples

julia> Flux.frequencies(['a','b','b'])
Dict{Char, Int64} with 2 entries:
  'a' => 1
  'b' => 2
Flux.batch (Function)
batch(xs)

Batch the arrays in xs into a single array.

See also unbatch.

Examples

julia> Flux.batch([[1,2,3],[4,5,6]])
3×2 Matrix{Int64}:
 1  4
 2  5
 3  6
Flux.unbatch (Function)

unbatch(x)

Reverse of the batch operation, unstacking the last dimension of the array x.

See also unstack.

Examples

julia> Flux.unbatch([1 3 5 7; 2 4 6 8])
4-element Vector{Vector{Int64}}:
 [1, 2]
 [3, 4]
 [5, 6]
 [7, 8]

Flux.batchseq (Function)
batchseq(seqs, pad)

Take a list of N sequences, and turn them into a single sequence where each item is a batch of N. Short sequences will be padded by pad.

Examples

julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
3-element Vector{Vector{Int64}}:
 [1, 4]
 [2, 5]
 [3, 0]
Base.rpad (Method)
rpad(v::AbstractVector, n::Integer, p)

Return the given sequence padded with p up to a maximum length of n.

Examples

julia> rpad([1, 2], 4, 0)
4-element Vector{Int64}:
 1
 2
 0
 0

julia> rpad([1, 2, 3], 2, 0)
3-element Vector{Int64}:
 1
 2
 3

Layer Initialization

These are primarily useful if you are planning to write your own layers. Flux initializes convolutional layers and recurrent cells with glorot_uniform by default. To change the default on an applicable layer, pass the desired function with the init keyword. For example:

julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
Conv((3, 3), 1 => 8, relu)  # 80 parameters
Flux.glorot_uniform (Function)
glorot_uniform([rng=GLOBAL_RNG], dims...)

Return an Array of size dims containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = sqrt(6 / (fan_in + fan_out)).

This method is described in [1] and also known as Xavier initialization.

Examples

julia> Flux.glorot_uniform(2, 3)
2×3 Matrix{Float32}:
 0.601094  -0.57414   -0.814925
 0.900868   0.805994   0.057514


References

[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.

Flux.glorot_normal (Function)
glorot_normal([rng=GLOBAL_RNG], dims...)

Return an Array of size dims containing random numbers drawn from a normal distribution with mean 0 and standard deviation sqrt(2 / (fan_in + fan_out)).

This method is described in [1] and also known as Xavier initialization.

Examples

julia> Flux.glorot_normal(3, 2)
3×2 Matrix{Float32}:
  0.429505  -0.0852891
  0.523935   0.371009
 -0.223261   0.188052


References

[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.

Flux.kaiming_uniform (Function)
kaiming_uniform([rng=GLOBAL_RNG], dims...; gain = √2)

Return an Array of size dims containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = gain * sqrt(3 / fan_in).

This method is described in [1] and also known as He initialization.

Examples

julia> Flux.kaiming_uniform(3, 2)
3×2 Matrix{Float32}:
  0.950413   1.27439
  1.4244    -1.28851
 -0.907795   0.0909376


References

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.

Flux.kaiming_normal (Function)
kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2)

Return an Array of size dims containing random numbers drawn from a normal distribution with mean 0 and standard deviation gain / sqrt(fan_in).

This method is described in [1] and also known as He initialization.

Examples

julia> Flux.kaiming_normal(3, 2)
3×2 Matrix{Float32}:
  0.679107  -0.134854
  0.828413   0.586617
 -0.353007   0.297336


References

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.

Flux.orthogonal (Function)
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)

Return an Array of size dims which is a (semi) orthogonal matrix, as described in [1].

The input must have at least 2 dimensions. For length(dims) > 2, a prod(dims[1:(end - 1)]) by dims[end] orthogonal matrix is computed before reshaping it to the original dimensions.

Examples

julia> W = Flux.orthogonal(5, 7);

julia> summary(W)
"5×7 Matrix{Float32}"

julia> W * W' ≈ I(5)
true

julia> W2 = Flux.orthogonal(7, 5);

julia> W2 * W2' ≈ I(7)
false

julia> W2' * W2 ≈ I(5)
true

julia> W3 = Flux.orthogonal(3, 3, 2, 4);

julia> transpose(reshape(W3, :, 4)) * reshape(W3, :, 4) ≈ I(4)
true


References

[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120

Flux.sparse_init (Function)
sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)

Return an Array of size dims where each column contains a fixed fraction of zero elements given by sparsity. Non-zero elements are normally distributed with a mean of zero and standard deviation std.

This method is described in [1].

Examples

julia> Flux.sparse_init(3, 2, sparsity=0.1)
3×2 Matrix{Float32}:
  0.00828413  0.0
 -0.00353007  0.00297336
  0.0         0.00586617


References

[1] Martens, J. "Deep learning via Hessian-free optimization." Proceedings of the 27th International Conference on Machine Learning. 2010.


Changing the type of model parameters

Flux.f64 (Function)
f64(m)

Convert the eltype of the model's parameters to Float64.

Flux.f32 (Function)
f32(m)

Convert the eltype of the model's parameters to Float32.


The default eltype for models is Float32 since models are often trained/run on GPUs. The eltype of model m can be changed to Float64 by f64(m), or to Float32 by f32(m).
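
For example, a minimal sketch (the model here is an arbitrary throwaway):

julia> m = Chain(Dense(3, 2), Dense(2, 1));

julia> eltype(m[1].weight)
Float32

julia> eltype(Flux.f64(m)[1].weight)
Float64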

Model Building

Flux provides some utility functions to help you generate models in an automated fashion.

outputsize enables you to calculate the output sizes of layers like Conv when applied to input samples of a given size. This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. outputsize(f, inputsize) works for all layers (including custom layers) out of the box. By default, inputsize is expected to include the batch dimension; alternatively, you can omit the batch dimension and call outputsize(f, inputsize; padbatch=true), which pads the size with a batch dimension of one.
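
For example, a quick sketch (the Conv layer and input size here are arbitrary); the two calls below are equivalent:

julia> Flux.outputsize(Conv((3, 3), 3 => 16), (128, 128, 3, 1))
(126, 126, 16, 1)

julia> Flux.outputsize(Conv((3, 3), 3 => 16), (128, 128, 3); padbatch=true)
(126, 126, 16, 1)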

Using this utility function lets you automate model building for various inputs like so:

"""
    make_model(width, height, inchannels, nclasses;
               layer_config = [16, 16, 32, 32, 64, 64])

Create a CNN for a given set of configuration parameters.

# Arguments
- `width`: the input image width
- `height`: the input image height
- `inchannels`: the number of channels in the input image
- `nclasses`: the number of output classes
- `layer_config`: a vector specifying the number of filters for each conv layer
"""
function make_model(width, height, inchannels, nclasses;
                    layer_config = [16, 16, 32, 32, 64, 64])
  # construct a vector of conv layers programmatically
  conv_layers = [Conv((3, 3), inchannels => layer_config[1])]
  for (infilters, outfilters) in zip(layer_config, layer_config[2:end])
    push!(conv_layers, Conv((3, 3), infilters => outfilters))
  end

  # compute the output dimensions for the conv layers
  # use padbatch=true to set the batch dimension to 1
  conv_outsize = Flux.outputsize(conv_layers, (width, height, inchannels); padbatch=true)

  # the input dimension to Dense is computed programmatically
  # from width, height, and inchannels
  return Chain(conv_layers..., Dense(prod(conv_outsize), nclasses))
end
Flux.outputsize (Function)
outputsize(m, x_size, y_size, ...; padbatch=false)

For a model or layer m accepting multiple arrays as input, this returns size(m((x, y, ...))) given x_size = size(x), etc.

Examples

julia> x, y = rand(Float32, 5, 64), rand(Float32, 7, 64);

julia> par = Parallel(vcat, Dense(5, 9), Dense(7, 11));

julia> Flux.outputsize(par, (5, 64), (7, 64))
(20, 64)

julia> m = Chain(par, Dense(20, 13), softmax);

julia> Flux.outputsize(m, (5,), (7,); padbatch=true)
(13, 1)

julia> par(x, y) == par((x, y)) == Chain(par, identity)((x, y))
true

Notice that Chain only accepts multiple arrays as a tuple, while Parallel also accepts them as multiple arguments; outputsize always supplies the tuple.


Model Abstraction

Flux.modules (Function)
modules(m)

Return an iterator over non-leaf objects that can be reached by recursing m over the children given by functor.

Useful for applying a function (e.g. a regularizer) over specific modules or subsets of the parameters (e.g. the weights but not the biases).

Examples

julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu));

julia> m2 = Chain(m1, Dense(64, 10))
Chain(
  Chain(
    Dense(784, 64),                     # 50_240 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
  ),
  Dense(64, 10),                        # 650 parameters
)         # Total: 6 trainable arrays, 51_018 parameters,
          # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.

julia> Flux.modules(m2)
5-element Vector{Any}:
 Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))  # 51_018 parameters, plus 128 non-trainable
 Chain(Dense(784, 64), BatchNorm(64, relu))  # 50_368 parameters, plus 128 non-trainable
 Dense(784, 64)      # 50_240 parameters
 BatchNorm(64, relu)  # 128 parameters, plus 128 non-trainable
 Dense(64, 10)       # 650 parameters

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)
Flux.destructure (Function)
destructure(m)

Flatten a model's parameters into a single weight vector.

julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

julia> θ, re = destructure(m);

julia> θ
67-element Vector{Float32}:
-0.1407104
...

The second return value re allows you to reconstruct the original network after making modifications to the weight vector (for example, with a hypernetwork).

julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Flux.nfan (Function)
nfan(n_out, n_in=1) -> Tuple
nfan(dims...)
nfan(dims::Tuple)

For a layer characterized by dimensions dims, return a tuple (fan_in, fan_out), where fan_in is the number of input neurons connected to an output one, and fan_out is the number of output neurons connected to an input one.

This function is mainly used by weight initializers, e.g., kaiming_normal.

Examples

julia> layer = Dense(10, 20);

julia> Flux.nfan(size(layer.weight))
(10, 20)

julia> layer = Conv((3, 3), 2=>10);

julia> Flux.nfan(size(layer.weight))
(18, 90)

Callback Helpers

Flux.throttle (Function)
throttle(f, timeout; leading=true, trailing=false)

Return a function that, when invoked, will trigger f at most once per timeout seconds.

Normally, the throttled function will run as much as it can, without ever going more than once per wait duration; but if you'd like to disable the execution on the leading edge, pass leading=false. To enable execution on the trailing edge, pass trailing=true.
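
A common pattern (sketch only; loss, test_x, test_y, model, data, and opt are assumed to be defined elsewhere) is to throttle an expensive logging callback passed to Flux.train!:

# evaluate the test loss at most once every 10 seconds,
# however often train! invokes the callback
evalcb = Flux.throttle(10) do
  @show loss(test_x, test_y)
end

Flux.train!(loss, Flux.params(model), data, opt; cb = evalcb)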

Flux.Optimise.stop (Function)
stop()

Call Flux.stop() in a callback to indicate when a callback condition is met. This will trigger the train loop to stop and exit.

Examples

cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
Flux.Optimise.skip (Function)
skip()

Call Flux.skip() in a callback to indicate when a callback condition is met. This will trigger the train loop to skip the current data point and not update with the calculated gradient.

Examples

cb = function ()
  loss() > 1e7 && Flux.skip()
end

Patience Helpers

Flux provides utilities for controlling your training procedure according to some monitored condition and a maximum patience. For example, you can use early_stopping to stop training when the model is converging or deteriorating, or you can use plateau to check if the model is stagnating.

For example, below we create a pseudo-loss function that decreases, bottoms out, then increases. The early stopping trigger will break the loop before the loss increases too much.

# create a pseudo-loss that decreases for 4 calls, then starts increasing
# we call this like loss()
loss = let t = 0
  () -> begin
    t += 1
    (t - 4) ^ 2
  end
end

# create an early stopping trigger
# returns true when the loss increases for two consecutive steps
es = early_stopping(loss, 2; init_score = 9)

# this will stop at the 6th (4 decreasing + 2 increasing calls) epoch
@epochs 10 begin
  es() && break
end

The keyword argument distance of early_stopping is a function of the form distance(best_score, score). By default, distance is -, which implies that the monitored metric f is expected to be decreasing and minimized. If you use an increasing metric (e.g. accuracy), you can customize the distance function: (best_score, score) -> score - best_score.

# create a pseudo-accuracy that increases by 0.01 each time from 0 to 1
# we call this like acc()
acc = let v = 0
  () -> v = min(1, v + 0.01)
end

# create an early stopping trigger for accuracy
es = early_stopping(acc, 3; distance = (best_score, score) -> score - best_score)

# this will iterate until the 10th epoch
@epochs 10 begin
  es() && break
end

early_stopping and plateau are both built on top of patience. You can use patience to build your own triggers that use a patient counter. For example, if you want to trigger when the loss is below a threshold for several consecutive iterations:

threshold(f, thresh, delay) = patience(delay) do
  f() < thresh
end
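
As a usage sketch (the cutoff of 1 and delay of 3 are arbitrary), this trigger fires once the pseudo-loss has stayed below the cutoff for three consecutive calls:

# pseudo-loss that decreases by one per call: 9, 8, 7, ...
loss = let t = 10
  () -> t -= 1
end

trigger = threshold(loss, 1, 3)

# loss() first drops below 1 on the 10th call, so this stops at the 12th epoch
@epochs 20 begin
  trigger() && break
end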

Both predicate in patience and f in early_stopping / plateau can accept extra arguments. You can pass such extra arguments to predicate or f through the returned function:

trigger = patience((a; b) -> a > b, 3)

# this will iterate until the 10th epoch
@epochs 10 begin
  trigger(1; b = 2) && break
end

# this will stop at the 3rd epoch
@epochs 10 begin
  trigger(3; b = 2) && break
end
Flux.patience (Function)
patience(predicate, wait)

Return a function that internally counts by one when predicate(...) == true, otherwise the count is reset to zero. If the count is greater than or equal to wait, the function returns true, otherwise it returns false.

Examples

julia> loss() = rand();

julia> trigger = Flux.patience(() -> loss() < 1, 3);


julia> Flux.@epochs 10 begin
         trigger() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
Flux.early_stopping (Function)
early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0)

Return a function that internally counts by one when distance(best_score, f(...)) <= min_dist, where best_score is the last seen best value of f(...). If the count is greater than or equal to delay, the function returns true, otherwise it returns false. The count is reset when distance(best_score, f(...)) > min_dist.

Examples

julia> loss = let l = 0
         () -> l += 1
       end; # pseudo loss function that returns increasing values

julia> es = Flux.early_stopping(loss, 3);


julia> Flux.@epochs 10 begin
         es() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
Flux.plateau (Function)
plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)

Return a function that internally counts by one when abs(distance(last_score, f(...))) <= min_dist, where last_score holds the last value of f(...). If the count is greater than or equal to width, the function returns true, otherwise it returns false. The count is reset when abs(distance(last_score, f(...))) > min_dist.

Examples

julia> f = let v = 10
         () -> v = v / abs(v) - v
       end; # -9, 8, -7, 6, ...

julia> trigger = Flux.plateau(f, 3; init_score=10, min_dist=18);


julia> Flux.@epochs 10 begin
         trigger() && break
       end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
[ Info: Epoch 4