Utility Functions

Flux provides utility functions for working with data; these help you create inputs for your models or batch your dataset. Other utilities initialize your layers or execute callback functions at regular intervals.

Working with Data

Flux.unsqueeze (Function)
unsqueeze(xs, dim)

Return xs reshaped into an Array with one more dimension than xs, where dim indicates the dimension in which xs is extended.

Examples

julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Array{Array{Int64,1},1}:
 [1, 2]
 [3, 4]
 [5, 6]

julia> Flux.unsqueeze(xs, 1)
1×3 Array{Array{Int64,1},2}:
 [1, 2]  [3, 4]  [5, 6]

julia> Flux.unsqueeze([1 2; 3 4], 2)
2×1×2 Array{Int64,3}:
[:, :, 1] =
 1
 3

[:, :, 2] =
 2
 4

Flux.stack (Function)
stack(xs, dim)

Concatenate the given Array of Arrays xs into a single Array along the given dimension dim.

Examples

julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Array{Array{Int64,1},1}:
 [1, 2]
 [3, 4]
 [5, 6]

julia> Flux.stack(xs, 1)
3×2 Array{Int64,2}:
 1  2
 3  4
 5  6

julia> cat(xs, dims=1)
3-element Array{Array{Int64,1},1}:
 [1, 2]
 [3, 4]
 [5, 6]

Flux.unstack (Function)
unstack(xs, dim)

Unroll the given xs into an Array of Arrays along the given dimension dim.

Examples

julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
4-element Array{Array{Int64,1},1}:
 [1, 2]
 [3, 4]
 [5, 6]
 [7, 8]

Flux.chunk (Function)
chunk(xs, n)

Split xs into n parts. The last part will be shorter if n does not evenly divide length(xs).

Examples

julia> Flux.chunk(1:10, 3)
3-element Array{UnitRange{Int64},1}:
 1:4
 5:8
 9:10

julia> Flux.chunk(collect(1:10), 3)
3-element Array{SubArray{Int64,1,Array{Int64,1},Tuple{UnitRange{Int64}},true},1}:
 [1, 2, 3, 4]
 [5, 6, 7, 8]
 [9, 10]

Flux.frequencies (Function)
frequencies(xs)

Count the number of times that each element of xs appears.

Examples

julia> Flux.frequencies(['a','b','b'])
Dict{Char,Int64} with 2 entries:
  'a' => 1
  'b' => 2

Flux.batch (Function)
batch(xs)

Batch the arrays in xs into a single array with an extra trailing batch dimension.

Examples

julia> Flux.batch([[1,2,3],[4,5,6]])
3×2 Array{Int64,2}:
 1  4
 2  5
 3  6

Flux.batchseq (Function)
batchseq(seqs, pad)

Take a list of N sequences and turn them into a single sequence where each item is a batch of N. Short sequences will be padded with pad.

Examples

julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
3-element Array{Array{Int64,1},1}:
 [1, 4]
 [2, 5]
 [3, 0]

Base.rpad (Method)

rpad(v::AbstractVector, n::Integer, p)

Return the given sequence padded with p up to a maximum length of n. If the sequence is already at least n long, it is returned unchanged.

Examples

julia> rpad([1, 2], 4, 0)
4-element Array{Int64,1}:
 1
 2
 0
 0

julia> rpad([1, 2, 3], 2, 0)
3-element Array{Int64,1}:
 1
 2
 3

Layer Initialization

These are primarily useful if you are planning to write your own layers. Flux initializes convolutional layers and recurrent cells with glorot_uniform by default. To change the default on an applicable layer, pass the desired function with the init keyword. For example:

julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
Conv((3, 3), 1=>8, relu)

Flux.glorot_uniform (Function)
glorot_uniform([rng=GLOBAL_RNG], dims...)

Return an Array of size dims containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = sqrt(6 / (fan_in + fan_out)).

This method is described in [1] and also known as Xavier initialization.

Examples

julia> Flux.glorot_uniform(2, 3)
2×3 Array{Float32,2}:
 0.601094  -0.57414   -0.814925
 0.900868   0.805994   0.057514

See also

  • glorot initialization using normal distribution: glorot_normal
  • kaiming initialization using normal distribution: kaiming_normal
  • kaiming initialization using uniform distribution: kaiming_uniform
  • calculation of fan_in and fan_out: nfan

References

[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.

Flux.glorot_normal (Function)
glorot_normal([rng=GLOBAL_RNG], dims...)

Return an Array of size dims containing random numbers drawn from a normal distribution with mean 0 and standard deviation sqrt(2 / (fan_in + fan_out)).

This method is described in [1] and also known as Xavier initialization.

Examples

julia> Flux.glorot_normal(3, 2)
3×2 Array{Float32,2}:
  0.429505  -0.0852891
  0.523935   0.371009
 -0.223261   0.188052

See also

  • glorot initialization using uniform distribution: glorot_uniform
  • kaiming initialization using normal distribution: kaiming_normal
  • kaiming initialization using uniform distribution: kaiming_uniform
  • calculation of fan_in and fan_out: nfan

References

[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.

Flux.kaiming_uniform (Function)
kaiming_uniform([rng=GLOBAL_RNG], dims...; gain = √2)

Return an Array of size dims containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = gain * sqrt(3 / fan_in).

This method is described in [1] and also known as He initialization.

Examples

julia> Flux.kaiming_uniform(3, 2)
3×2 Array{Float32,2}:
  0.950413   1.27439
  1.4244    -1.28851
 -0.907795   0.0909376

See also

  • kaiming initialization using normal distribution: kaiming_normal
  • glorot initialization using normal distribution: glorot_normal
  • glorot initialization using uniform distribution: glorot_uniform
  • calculation of fan_in and fan_out: nfan

References

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.

Flux.kaiming_normal (Function)
kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2)

Return an Array of size dims containing random numbers drawn from a normal distribution with mean 0 and standard deviation gain / sqrt(fan_in).

This method is described in [1] and also known as He initialization.

Examples

julia> Flux.kaiming_normal(3, 2)
3×2 Array{Float32,2}:
  0.679107  -0.134854
  0.828413   0.586617
 -0.353007   0.297336

See also

  • kaiming initialization using uniform distribution: kaiming_uniform
  • glorot initialization using normal distribution: glorot_normal
  • glorot initialization using uniform distribution: glorot_uniform
  • calculation of fan_in and fan_out: nfan

References

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.


Model Building

Flux provides some utility functions to help you generate models in an automated fashion.

outputsize enables you to calculate the output sizes of layers like Conv when applied to input samples of a given size. This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. outputsize(f, inputsize) works for all layers (including custom layers) out of the box. By default, inputsize is expected to include the batch dimension; alternatively, you can omit it and call outputsize(f, inputsize; padbatch=true), which pads a batch size of one onto inputsize.

Using this utility function lets you automate model building for various inputs like so:

"""
    make_model(width, height, inchannels, nclasses;
               layer_config = [16, 16, 32, 32, 64, 64])

Create a CNN for a given set of configuration parameters.

# Arguments
- `width`: the input image width
- `height`: the input image height
- `inchannels`: the number of channels in the input image
- `nclasses`: the number of output classes
- `layer_config`: a vector of the number of filters per each conv layer
"""
function make_model(width, height, inchannels, nclasses;
                    layer_config = [16, 16, 32, 32, 64, 64])
  # construct a vector of conv layers programmatically
  conv_layers = [Conv((3, 3), inchannels => layer_config[1])]
  for (infilters, outfilters) in zip(layer_config, layer_config[2:end])
    push!(conv_layers, Conv((3, 3), infilters => outfilters))
  end

  # compute the output dimensions for the conv layers
  # use padbatch=true to set the batch dimension to 1
  conv_outsize = Flux.outputsize(conv_layers, (width, height, inchannels); padbatch=true)

  # the input dimension to Dense is programmatically calculated from
  #  width, height, and inchannels
  return Chain(conv_layers..., Flux.flatten, Dense(prod(conv_outsize), nclasses))
end
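
For example, a quick check of the generated model's output size (the 32×32 RGB input and 10 classes are just illustrative values):

julia> m = make_model(32, 32, 3, 10);

julia> Flux.outputsize(m, (32, 32, 3); padbatch=true)
(10, 1)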

Flux.outputsize (Function)
outputsize(m, inputsize::Tuple; padbatch=false)

Calculate the output size of model m given the input size. Obeys outputsize(m, size(x)) == size(m(x)) for valid input x. Keyword padbatch=true is equivalent to using (inputsize..., 1), and returns the final size including this extra batch dimension.

This should be faster than calling size(m(x)). It uses a trivial number type, and thus should work out of the box for custom layers.

If m is a Tuple or Vector, its elements are applied in sequence, like Chain(m...).

Examples

julia> using Flux: outputsize

julia> outputsize(Dense(10, 4), (10,); padbatch=true)
(4, 1)

julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));

julia> m(randn(Float32, 10, 10, 3, 64)) |> size
(6, 6, 32, 64)

julia> outputsize(m, (10, 10, 3); padbatch=true)
(6, 6, 32, 1)

julia> outputsize(m, (10, 10, 3, 64))
(6, 6, 32, 64)

julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
DimensionMismatch("Input channels must match! (7 vs. 3)")

julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))
(2, 1)

julia> using LinearAlgebra: norm

julia> f(x) = x ./ norm.(eachcol(x));

julia> outputsize(f, (10, 1)) # manually specify batch size as 1
(10, 1)

julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size
(10, 1)

Model Abstraction

Flux.destructure (Function)
destructure(m)

Flatten a model's parameters into a single weight vector.

julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

julia> θ, re = destructure(m);

julia> θ
67-element Array{Float32,1}:
-0.1407104
...

The second return value re allows you to reconstruct the original network after making modifications to the weight vector (for example, with a hypernetwork).

julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
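
Because re is differentiable, you can also take gradients with respect to the flat parameter vector directly; a minimal sketch (the random input x and the sum-of-squares objective are just illustrative):

julia> x = rand(Float32, 10);

julia> g = Flux.gradient(θ -> sum(abs2, re(θ)(x)), θ);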

Flux.nfan (Function)
nfan(n_out, n_in=1) -> Tuple
nfan(dims...)
nfan(dims::Tuple)

For a layer characterized by dimensions dims, return a tuple (fan_in, fan_out), where fan_in is the number of input neurons connected to an output one, and fan_out is the number of output neurons connected to an input one.

This function is mainly used by weight initializers, e.g., kaiming_normal.

Examples

julia> layer = Dense(10, 20)
Dense(10, 20)

julia> Flux.nfan(size(layer.W))
(10, 20)

julia> layer = Conv((3, 3), 2=>10)
Conv((3, 3), 2=>10)

julia> Flux.nfan(size(layer.weight))
(18, 90)
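
For the Conv layer the kernel size enters the count: fan_in = 3 × 3 × 2 = 18 and fan_out = 3 × 3 × 10 = 90, i.e. the input and output channel counts multiplied by the size of the receptive field.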

Callback Helpers

Flux.throttle (Function)
throttle(f, timeout; leading=true, trailing=false)

Return a function that, when invoked, will be triggered at most once during timeout seconds.

Normally, the throttled function will run as much as it can, without ever going more than once per wait duration; but if you'd like to disable the execution on the leading edge, pass leading=false. To enable execution on the trailing edge, pass trailing=true.
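
As an illustration, a minimal sketch (all calls happen well inside the 5-second window, so only the leading call fires):

julia> logger = Flux.throttle(() -> println("log!"), 5);

julia> for _ in 1:10
         logger()
       end
log!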

Flux.Optimise.stop (Function)
stop()

Call Flux.stop() in a callback to signal that a stopping condition has been met. This will cause the training loop to stop and exit.

Examples

cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
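
The callback is passed to the training loop via the cb keyword; a sketch, assuming loss, ps, data, and opt are already defined:

Flux.train!(loss, ps, data, opt; cb = cb)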

Flux.Optimise.skip (Function)
skip()

Call Flux.skip() in a callback to signal that the current data point should be skipped. This will cause the training loop to move on without updating the model with the calculated gradient.

Examples

cb = function ()
  loss() > 1e7 && Flux.skip()
end