Utility Functions

Flux provides utility functions for working with data: some help create inputs for your models or batch your dataset, while others initialize layers or run callback functions at regular intervals.

Working with Data

Flux.unsqueeze - Function
unsqueeze(xs, dim)

Return xs reshaped into an Array with one more dimension than xs, where dim is the position at which the new dimension of size 1 is inserted.

Examples

julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Array{Array{Int64,1},1}:
 [1, 2]
 [3, 4]
 [5, 6]

julia> Flux.unsqueeze(xs, 1)
1×3 Array{Array{Int64,1},2}:
 [1, 2]  [3, 4]  [5, 6]

julia> Flux.unsqueeze([1 2; 3 4], 2)
2×1×2 Array{Int64,3}:
[:, :, 1] =
 1
 3

[:, :, 2] =
 2
 4
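
The reshaping described above can be reproduced with Base functions alone. The following is a minimal sketch, not necessarily how Flux implements it; `unsqueeze_sketch` is a hypothetical name introduced here for illustration.

```julia
# Hypothetical helper: insert a dimension of size 1 at position `dim` via `reshape`.
function unsqueeze_sketch(xs::AbstractArray, dim::Integer)
    sz = size(xs)
    # Keep the sizes before `dim`, add a 1, then keep the remaining sizes.
    return reshape(xs, (sz[1:dim-1]..., 1, sz[dim:end]...))
end

A = [1 2; 3 4]
B = unsqueeze_sketch(A, 2)
size(B)        # (2, 1, 2)
B[:, 1, 2]     # [2, 4], the second column of A
```

Because `reshape` shares memory with its argument, this sketch does not copy the data.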

Flux.stack - Function
stack(xs, dim)

Concatenate the given Array of Arrays xs into a single Array along the given dimension dim.

Examples

julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Array{Array{Int64,1},1}:
 [1, 2]
 [3, 4]
 [5, 6]

julia> Flux.stack(xs, 1)
3×2 Array{Int64,2}:
 1  2
 3  4
 5  6

julia> cat(xs, dims=1)
3-element Array{Array{Int64,1},1}:
 [1, 2]
 [3, 4]
 [5, 6]

Flux.unstack - Function
unstack(xs, dim)

Unroll the given xs into an Array of Arrays along the given dimension dim.

Examples

julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
4-element Array{Array{Int64,1},1}:
 [1, 2]
 [3, 4]
 [5, 6]
 [7, 8]
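
Stacking and unstacking along the same dimension are inverse operations. The sketch below illustrates this round trip with Base functions only; `add_dim`, `stack_sketch`, and `unstack_sketch` are hypothetical helpers, not Flux's API.

```julia
# Hypothetical helpers, written with Base functions only.
add_dim(x, dim) = reshape(x, (size(x)[1:dim-1]..., 1, size(x)[dim:end]...))

# Give every element a new axis at `dim`, then concatenate along that axis.
stack_sketch(xs, dim) = cat((add_dim(x, dim) for x in xs)...; dims=dim)

# Slice along `dim` and copy each slice into its own array.
unstack_sketch(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]

xs = [[1, 2], [3, 4], [5, 6]]
m = stack_sketch(xs, 1)        # 3×2 matrix with rows [1 2], [3 4], [5 6]
unstack_sketch(m, 1) == xs     # true
```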

Flux.chunk - Function
chunk(xs, n)

Split xs into n parts.

Examples

julia> Flux.chunk(1:10, 3)
3-element Array{UnitRange{Int64},1}:
 1:4
 5:8
 9:10

julia> Flux.chunk(collect(1:10), 3)
3-element Array{SubArray{Int64,1,Array{Int64,1},Tuple{UnitRange{Int64}},true},1}:
 [1, 2, 3, 4]
 [5, 6, 7, 8]
 [9, 10]
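
The part sizes in the examples above are consistent with giving every part `ceil(length(xs) / n)` elements. A minimal sketch under that assumption follows; `chunk_sketch` is a hypothetical name and may not match Flux's implementation in every version.

```julia
# Hypothetical helper: partition `xs` into pieces of ceil(length(xs) / n) elements.
chunk_sketch(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs) / n)))

chunk_sketch(1:10, 3)   # parts 1:4, 5:8, 9:10
chunk_sketch(1:4, 3)    # only two parts: 1:2 and 3:4
```

Under this sizing rule the last part can be shorter than the others, and the result can contain fewer than n parts when the length of xs is small relative to n.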

Flux.frequencies - Function
frequencies(xs)

Count the number of times that each element of xs appears.

Examples

julia> Flux.frequencies(['a','b','b'])
Dict{Char,Int64} with 2 entries:
  'a' => 1
  'b' => 2
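
For illustration, the same counts can be computed with a plain `Dict`. This is a sketch with a hypothetical name, `frequencies_sketch`, and is not taken from Flux's source.

```julia
# Hypothetical helper: count occurrences of each element with a Dict.
function frequencies_sketch(xs)
    counts = Dict{eltype(xs),Int}()
    for x in xs
        counts[x] = get(counts, x, 0) + 1   # start at 0 for unseen elements
    end
    return counts
end

frequencies_sketch(['a', 'b', 'b'])   # 'a' => 1, 'b' => 2
```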

Flux.batch - Function
batch(xs)

Batch the arrays in xs into a single array.

Examples

julia> Flux.batch([[1,2,3],[4,5,6]])
3×2 Array{Int64,2}:
 1  4
 2  5
 3  6

Flux.batchseq - Function
batchseq(seqs, pad)

Take a list of N sequences, and turn them into a single sequence where each item is a batch of N. Short sequences will be padded by pad.

Examples

julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0)
3-element Array{Array{Int64,1},1}:
 [1, 4]
 [2, 5]
 [3, 0]
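
The transformation above amounts to padding every sequence to the longest length and then collecting the items at each time step. A minimal sketch of that idea, using the hypothetical name `batchseq_sketch`:

```julia
# Hypothetical helper: pad to the longest sequence, then regroup by time step.
function batchseq_sketch(seqs, pad)
    n = maximum(length, seqs)
    padded = [[s; fill(pad, n - length(s))] for s in seqs]
    # Item t of the result holds element t of every padded sequence.
    return [[p[t] for p in padded] for t in 1:n]
end

batchseq_sketch([[1, 2, 3], [4, 5]], 0)   # [[1, 4], [2, 5], [3, 0]]
```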

Base.rpad - Method
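rpad(v::AbstractVector, n::Integer, p)

Return the given sequence padded on the right with p until it has length n. A sequence that is already at least n elements long is returned without truncation.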

Examples

julia> rpad([1, 2], 4, 0)
4-element Array{Int64,1}:
 1
 2
 0
 0

julia> rpad([1, 2, 3], 2, 0)
3-element Array{Int64,1}:
 1
 2
 3
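
Both examples follow from appending `max(n - length(v), 0)` copies of the padding value. The sketch below uses a separate, hypothetical name (`rpad_sketch`) so that it does not add a method to `Base.rpad`.

```julia
# Hypothetical helper: append padding only when the vector is shorter than `n`.
rpad_sketch(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]

rpad_sketch([1, 2], 4, 0)      # [1, 2, 0, 0]
rpad_sketch([1, 2, 3], 2, 0)   # [1, 2, 3], nothing is removed
```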

Layer Initialization

These are primarily useful if you are planning to write your own layers. Flux initializes convolutional layers and recurrent cells with glorot_uniform by default. To change the default on an applicable layer, pass the desired function with the init keyword. For example:

julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
Conv((3, 3), 1=>8, relu)

Flux.glorot_uniform - Function
glorot_uniform(dims...)

Return an Array of size dims containing random values drawn from a uniform distribution on the interval [-x, x], where x = sqrt(24 / sum(dims)) / 2 (equivalently, sqrt(6 / sum(dims))).

Examples

julia> Flux.glorot_uniform(2, 3)
2×3 Array{Float32,2}:
 0.601094  -0.57414   -0.814925
 0.900868   0.805994   0.057514
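
To make the bound concrete, the formula can be written out directly. This is a sketch that follows the docstring rather than Flux's source; `glorot_uniform_sketch` and the explicit `rng` argument are assumptions introduced for reproducibility.

```julia
using Random

# Hypothetical helper: sample uniformly within ±x, where x = sqrt(24 / sum(dims)) / 2.
function glorot_uniform_sketch(rng::AbstractRNG, dims::Integer...)
    x = sqrt(24f0 / sum(dims)) / 2
    # rand gives values in [0, 1); rescale to [-1, 1) and then to [-x, x).
    return (rand(rng, Float32, dims...) .* 2 .- 1) .* x
end

rng = MersenneTwister(0)
W = glorot_uniform_sketch(rng, 2, 3)
bound = sqrt(24f0 / 5) / 2     # about 1.095 for dims = (2, 3)
all(abs.(W) .<= bound)         # true
```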

Flux.glorot_normal - Function
glorot_normal(dims...)

Return an Array of size dims containing random values drawn from a normal distribution with mean 0 and standard deviation sqrt(2 / sum(dims)).

Examples

julia> Flux.glorot_normal(3, 2)
3×2 Array{Float32,2}:
  0.429505  -0.0852891
  0.523935   0.371009
 -0.223261   0.188052
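
As a sanity check, the sketch below assumes the usual Glorot convention of a standard deviation of sqrt(2 / sum(dims)), which is the same variance as a uniform distribution on [-sqrt(6 / sum(dims)), sqrt(6 / sum(dims))]. The name `glorot_normal_sketch` and the explicit `rng` argument are hypothetical.

```julia
using Random, Statistics

# Hypothetical helper: scale standard normal samples by sqrt(2 / sum(dims)).
glorot_normal_sketch(rng::AbstractRNG, dims::Integer...) =
    randn(rng, Float32, dims...) .* sqrt(2f0 / sum(dims))

rng = MersenneTwister(0)
W = glorot_normal_sketch(rng, 400, 600)
target = sqrt(2f0 / 1000)               # about 0.0447
isapprox(std(W), target; rtol=0.05)     # true for a sample this large
```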

Model Abstraction

Flux.destructure - Function
destructure(m)

Flatten a model's parameters into a single weight vector.

julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

julia> θ, re = destructure(m);

julia> θ
67-element Array{Float32,1}:
-0.1407104
...

The second return value re allows you to reconstruct the original network after making modifications to the weight vector (for example, with a hypernetwork).

julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
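
The flatten-and-rebuild idea can be illustrated without Flux by treating a model as a plain list of parameter arrays. The sketch below is only an analogy: `destructure_sketch` is a hypothetical helper, and the four arrays stand in for the weights and biases of the two Dense layers above.

```julia
# Hypothetical helper: flatten a list of arrays and return a function that restores the shapes.
function destructure_sketch(ps)
    shapes = size.(ps)
    θ = reduce(vcat, vec.(ps))
    function re(v::AbstractVector)
        out = Array{eltype(v)}[]
        offset = 0
        for s in shapes
            len = prod(s)
            # Take the next `len` entries and give them their original shape.
            push!(out, reshape(v[offset+1:offset+len], s))
            offset += len
        end
        return out
    end
    return θ, re
end

# Stand-ins for the parameters of Dense(10, 5) and Dense(5, 2).
ps = [zeros(Float32, 5, 10), zeros(Float32, 5), zeros(Float32, 2, 5), zeros(Float32, 2)]
θ, re = destructure_sketch(ps)
length(θ)                          # 67 = 5*10 + 5 + 2*5 + 2
size.(re(θ .* 2)) == size.(ps)     # true
```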

Callback Helpers

Flux.throttle - Function
throttle(f, timeout; leading=true, trailing=false)

Return a function that, when invoked, runs f at most once every timeout seconds.

Normally, the throttled function runs as often as it can without ever running more than once per timeout period. To disable execution on the leading edge, pass leading=false. To enable execution on the trailing edge, pass trailing=true.
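
The leading-edge behaviour can be sketched in a few lines of plain Julia. This is a simplified illustration that omits the leading and trailing options; `throttle_sketch` is a hypothetical name, not Flux's implementation.

```julia
# Hypothetical helper: run `f` only if at least `timeout` seconds have passed
# since the last run; other calls are ignored and return `nothing`.
function throttle_sketch(f, timeout)
    last_run = -Inf
    function throttled(args...)
        t = time()
        if t - last_run >= timeout
            last_run = t
            return f(args...)
        end
        return nothing
    end
    return throttled
end

calls = Ref(0)
report = throttle_sketch(() -> (calls[] += 1), 60)
report(); report(); report()
calls[]   # 1, because the second and third calls fall inside the 60 second window
```

A throttled function like this is typically passed as a callback so that logging or evaluation does not run on every training step.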

Flux.Optimise.stop - Function
stop()

Call Flux.stop() in a callback to indicate when a callback condition is met. This will trigger the train loop to stop and exit.

Examples

cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
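
One way a training loop can honour such a request is for the stop function to throw an exception that the loop catches. The sketch below illustrates that pattern with hypothetical names (`StopSketch`, `stop_sketch`, `train_sketch`); it is an assumption about the general mechanism, not a copy of Flux's train loop.

```julia
# Hypothetical exception type used to signal an early exit.
struct StopSketch <: Exception end
stop_sketch() = throw(StopSketch())

# Run `step` up to `nsteps` times, calling `cb` after each step.
function train_sketch(step, nsteps, cb)
    completed = 0
    for i in 1:nsteps
        step(i)
        completed = i
        try
            cb()
        catch err
            err isa StopSketch || rethrow()   # unrelated errors still propagate
            break
        end
    end
    return completed
end

acc = Ref(0.0)
n = train_sketch(i -> (acc[] = i / 10), 100, () -> acc[] > 0.9 && stop_sketch())
n   # 10, the first step at which the accuracy stand-in exceeds 0.9
```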