Random Weight Initialisation
Flux initialises convolutional layers and recurrent cells with glorot_uniform by default. Most layers accept a function as an init keyword, which replaces this default. For example:
julia> conv = Conv((3, 3), 3 => 2, relu; init=Flux.glorot_normal)
Conv((3, 3), 3 => 2, relu) # 56 parameters
julia> conv.bias
2-element Vector{Float32}:
0.0
0.0
Note that init creates the weight array, but not the bias vector.
Many of the initialisation functions accept keywords such as gain, and a random number generator. To make it easy to pass these to layers, there are methods which return a function:
julia> Dense(4 => 5, tanh; init=Flux.glorot_uniform(gain=2))
Dense(4 => 5, tanh) # 25 parameters
julia> using Random
julia> Dense(4 => 5, tanh; init=Flux.randn32(MersenneTwister(1)))
Dense(4 => 5, tanh) # 25 parameters
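To illustrate the two-method pattern, here is a minimal sketch of a custom initialiser written in the same style, using only the Random standard library. The name scaled_uniform is hypothetical (it is not part of Flux), and the comment that a layer such as Dense(4 => 5) calls its init with the sizes (5, 4) reflects our understanding of Flux's convention rather than anything stated on this page.

```julia
using Random

# Hypothetical initialiser `scaled_uniform` (not part of Flux), written in the
# same two-method style as the built-in ones.

# Array method: optional RNG first, then the dimensions, then keywords.
function scaled_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = 1)
    # rand gives [0, 1); shift and scale to [-gain, gain), stored as Float32
    return (rand(rng, Float32, dims...) .- 0.5f0) .* Float32(2 * gain)
end
scaled_uniform(dims::Integer...; kw...) = scaled_uniform(Random.default_rng(), dims...; kw...)

# Function method: no dimensions given, so capture the RNG and keywords in a
# closure that a layer constructor can later call with just the sizes.
scaled_uniform(rng::AbstractRNG; kw...) = (dims...) -> scaled_uniform(rng, dims...; kw...)
scaled_uniform(; kw...) = scaled_uniform(Random.default_rng(); kw...)

init = scaled_uniform(MersenneTwister(1); gain = 2)
W = init(5, 4)   # e.g. what Dense(4 => 5) would request (assumed convention: out, in)
```

Returning a closure fixes the RNG and keywords up front while leaving the array sizes to the layer constructor, which is why these methods can be passed directly as init.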
Initialisation functions
Flux.glorot_uniform — Function
glorot_uniform([rng], size...; gain = 1) -> Array
glorot_uniform([rng]; kw...) -> Function
Return an Array{Float32} of the given size containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = gain * sqrt(6 / (fan_in + fan_out)).
This method is described in [1] and also known as Xavier initialization.
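As a sketch of the formula above (not Flux's source code), the bound x can be computed directly and used to rescale rand. The helper glorot_uniform_sketch is hypothetical and handles only the 2-D case, where the weight has size (fan_out, fan_in), matching the nfan(n_out, n_in) convention.

```julia
using Random

# Hypothetical helper illustrating the Glorot/Xavier uniform formula for a
# 2-D weight of size (fan_out, fan_in); this is not Flux's source code.
function glorot_uniform_sketch(rng::AbstractRNG, fan_out::Integer, fan_in::Integer; gain::Real = 1)
    x = Float32(gain * sqrt(6 / (fan_in + fan_out)))   # half-width of the interval
    # rand gives values in [0, 1); rescale them to [-x, x)
    return (2 .* rand(rng, Float32, fan_out, fan_in) .- 1) .* x
end

W = glorot_uniform_sketch(MersenneTwister(0), 10, 100)
bound = sqrt(6 / (100 + 10))   # ≈ 0.2335, consistent with the extrema in the examples
```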
Examples
julia> Flux.glorot_uniform(3, 4) |> summary
"3×4 Matrix{Float32}"
julia> round.(extrema(Flux.glorot_uniform(10, 100)), digits=3)
(-0.233f0, 0.233f0)
julia> round.(extrema(Flux.glorot_uniform(100, 10)), digits=3)
(-0.234f0, 0.233f0)
julia> round.(extrema(Flux.glorot_uniform(100, 100)), digits=3)
(-0.173f0, 0.173f0)
julia> using Random
julia> Dense(3 => 2, tanh; init = Flux.glorot_uniform(MersenneTwister(1)))
Dense(3 => 2, tanh) # 8 parameters
julia> ans.bias
2-element Vector{Float32}:
0.0
0.0
References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
Flux.glorot_normal — Function
glorot_normal([rng], size...; gain = 1) -> Array
glorot_normal([rng]; kw...) -> Function
Return an Array{Float32} of the given size containing random numbers drawn from a normal distribution with standard deviation gain * sqrt(2 / (fan_in + fan_out)), using nfan.
This method is described in [1] and also known as Xavier initialization.
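A comparable sketch for the normal variant, again restricted to 2-D weights and using a hypothetical helper name. With gain = 100 and a 1000×10 weight, the sample standard deviation should land close to the 4.45 shown in the Dense example below; the exact digits depend on the RNG.

```julia
using Random, Statistics

# Hypothetical helper for the Glorot/Xavier normal formula, 2-D case only:
# standard normal samples scaled by gain * sqrt(2 / (fan_in + fan_out)).
function glorot_normal_sketch(rng::AbstractRNG, fan_out::Integer, fan_in::Integer; gain::Real = 1)
    σ = Float32(gain * sqrt(2 / (fan_in + fan_out)))
    return randn(rng, Float32, fan_out, fan_in) .* σ
end

# Same shape and gain as the Dense(10 => 1000) example: the weight is 1000 × 10
W = glorot_normal_sketch(MersenneTwister(0), 1000, 10; gain = 100)
expected = 100 * sqrt(2 / (10 + 1000))   # ≈ 4.45
println(round(std(W), sigdigits = 3))     # near 4.45; exact digits depend on the RNG
```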
Examples
julia> using Statistics
julia> round(std(Flux.glorot_normal(10, 1000)), digits=3)
0.044f0
julia> round(std(Flux.glorot_normal(1000, 10)), digits=3)
0.045f0
julia> round(std(Flux.glorot_normal(1000, 1000)), digits=3)
0.032f0
julia> Dense(10 => 1000, tanh; init = Flux.glorot_normal(gain=100))
Dense(10 => 1000, tanh) # 11_000 parameters
julia> round(std(ans.weight), sigdigits=3)
4.45f0
References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
Flux.kaiming_uniform — Function
kaiming_uniform([rng], size...; gain = √2) -> Array
kaiming_uniform([rng]; kw...) -> Function
Return an Array{Float32} of the given size containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = gain * sqrt(3/fan_in), using nfan.
This method is described in [1] and also known as He initialization.
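Unlike the Glorot rules, only fan_in enters this formula, which explains why the bound in the examples below depends on the second dimension alone. A minimal sketch with a hypothetical 2-D helper (not Flux's implementation):

```julia
using Random

# Hypothetical helper for the He/Kaiming uniform rule on a (fan_out, fan_in)
# weight: half-width x = gain * sqrt(3 / fan_in), with gain = √2 by default.
function kaiming_uniform_sketch(rng::AbstractRNG, fan_out::Integer, fan_in::Integer; gain::Real = √2)
    x = Float32(gain * sqrt(3 / fan_in))
    return (2 .* rand(rng, Float32, fan_out, fan_in) .- 1) .* x
end

bound_fanin_10  = √2 * sqrt(3 / 10)    # = sqrt(0.6)  ≈ 0.775, the (100, 10) case
bound_fanin_100 = √2 * sqrt(3 / 100)   # = sqrt(0.06) ≈ 0.245, the (10, 100) and (100, 100) cases
W = kaiming_uniform_sketch(MersenneTwister(0), 100, 10)
```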
Examples
julia> round.(extrema(Flux.kaiming_uniform(100, 10)), digits=3)
(-0.774f0, 0.773f0)
julia> round.(extrema(Flux.kaiming_uniform(10, 100)), digits=3)
(-0.243f0, 0.245f0)
julia> round.(extrema(Flux.kaiming_uniform(100, 100)), digits=3)
(-0.245f0, 0.245f0)
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
Flux.kaiming_normal — Function
kaiming_normal([rng], size...; gain = √2) -> Array
kaiming_normal([rng]; kw...) -> Function
Return an Array{Float32} of the given size containing random numbers drawn from a normal distribution with standard deviation gain / sqrt(fan_in), using nfan.
This method is described in [1] and also known as He initialization.
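A sketch of the normal variant with a hypothetical 2-D helper. For the 1000×10 case in the examples below, fan_in = 10, so the target standard deviation is √2 / √10 = √0.2 ≈ 0.447.

```julia
using Random, Statistics

# Hypothetical helper for the He/Kaiming normal rule on a (fan_out, fan_in)
# weight: standard deviation gain / sqrt(fan_in), gain = √2 by default.
function kaiming_normal_sketch(rng::AbstractRNG, fan_out::Integer, fan_in::Integer; gain::Real = √2)
    σ = Float32(gain / sqrt(fan_in))
    return randn(rng, Float32, fan_out, fan_in) .* σ
end

W = kaiming_normal_sketch(MersenneTwister(0), 1000, 10)   # fan_in = 10
expected = √2 / sqrt(10)   # ≈ 0.447
```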
Examples
julia> using Statistics
julia> round(std(Flux.kaiming_normal(10, 1000)), digits=3)
0.044f0
julia> round(std(Flux.kaiming_normal(1000, 10)), digits=3)
0.447f0
julia> round(std(Flux.kaiming_normal(1000, 1000)), digits=3)
0.045f0
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
Flux.truncated_normal — Function
truncated_normal([rng], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
truncated_normal([rng]; kw...) -> Function
Return an Array{Float32} of the given size where each element is drawn from a truncated normal distribution. The numbers are distributed like filter(x -> lo<=x<=hi, mean .+ std .* randn(100)).
The values are generated by sampling a Uniform(0, 1) (rand()) and then applying the inverse CDF of the truncated normal distribution. This method works best when lo ≤ mean ≤ hi.
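The filter description above suggests a simpler, if less efficient, alternative: rejection sampling. The sketch below uses that swapped-in technique, not the inverse-CDF method Flux uses, and the helper name is hypothetical. Both target the same distribution, but rejection wastes many draws when [lo, hi] sits far out in a tail of the normal.

```julia
using Random, Statistics

# Rejection-sampling sketch of a truncated normal (hypothetical helper).
# It follows `filter(x -> lo <= x <= hi, mean .+ std .* randn(...))`,
# NOT the inverse-CDF approach described above.
function truncated_normal_sketch(rng::AbstractRNG, dims::Integer...;
                                 mean = 0, std = 1, lo = -2, hi = 2)
    out = Array{Float32}(undef, dims...)
    for i in eachindex(out)
        while true
            x = mean + std * randn(rng)   # propose from the untruncated normal
            if lo <= x <= hi              # keep only values inside [lo, hi]
                out[i] = x
                break
            end
        end
    end
    return out
end

A = truncated_normal_sketch(MersenneTwister(0), 1000)
B = truncated_normal_sketch(MersenneTwister(1), 10_000; lo = -100, hi = 100)
```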
Examples
julia> using Statistics
julia> Flux.truncated_normal(3, 4) |> summary
"3×4 Matrix{Float32}"
julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3)
(-2.0f0, 2.0f0)
julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
1.0f0
Flux.lecun_normal — Function
lecun_normal([rng], size...) -> Array
lecun_normal([rng]; kw...) -> Function
Return an Array{Float32} of the given size containing random numbers drawn from a truncated normal distribution centered on 0 with standard deviation sqrt(1 / fan_in), where fan_in is the number of input units in the weight tensor.
This method is described in [1].
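A sketch of the same idea for a 2-D weight, with a hypothetical helper. The truncation bounds are an assumption: here they are the absolute limits lo = -2, hi = 2 taken from truncated_normal's defaults. For typical fan_in these limits lie many standard deviations out, so they barely change the spread, which is consistent with the standard deviations in the examples below.

```julia
using Random, Statistics

# Hypothetical LeCun-normal sketch for a (fan_out, fan_in) weight.
# Assumption: values are cut off at the absolute bounds [-2, 2] (truncated_normal's
# defaults), not at ±2 standard deviations.
function lecun_normal_sketch(rng::AbstractRNG, fan_out::Integer, fan_in::Integer)
    σ = sqrt(1 / fan_in)
    W = Array{Float32}(undef, fan_out, fan_in)
    for i in eachindex(W)
        x = σ * randn(rng)
        while !(-2 <= x <= 2)   # redraw anything outside [-2, 2]
            x = σ * randn(rng)
        end
        W[i] = x
    end
    return W
end

W = lecun_normal_sketch(MersenneTwister(0), 1000, 10)   # fan_in = 10
expected = sqrt(1 / 10)   # ≈ 0.316
```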
Examples
julia> using Statistics
julia> round(std(Flux.lecun_normal(10, 1000)), digits=3)
0.032f0
julia> round(std(Flux.lecun_normal(1000, 10)), digits=3)
0.32f0
julia> round(std(Flux.lecun_normal(1000, 1000)), digits=3)
0.032f0
julia> Dense(10 => 1000, selu; init = Flux.lecun_normal())
Dense(10 => 1000, selu) # 11_000 parameters
julia> round(std(ans.weight), digits=3)
0.313f0
References
[1] Lecun, Yann, et al. "Efficient backprop." Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 9-48.
Flux.orthogonal — Function
orthogonal([rng], size...; gain = 1) -> Array
orthogonal([rng]; kw...) -> Function
Return an Array{Float32} of the given size which is a (semi) orthogonal matrix, as described in [1].
Cannot construct a vector, i.e. length(size) == 1 is forbidden. For length(size) > 2, a prod(size[1:(end - 1)]) by size[end] orthogonal matrix is computed before reshaping it to the original dimensions.
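One common construction takes the Q factor of a QR decomposition of a Gaussian random matrix. The sketch below assumes that construction; whether Flux does exactly this is not stated here, and orthogonal_sketch is a hypothetical name. Only the 2-D case is shown.

```julia
using LinearAlgebra, Random

# Hypothetical (semi-)orthogonal initialiser via QR of a Gaussian matrix.
function orthogonal_sketch(rng::AbstractRNG, rows::Integer, cols::Integer; gain::Real = 1)
    # Factorise the "tall" orientation so that Q has orthonormal columns
    tall = rows >= cols
    A = tall ? randn(rng, rows, cols) : randn(rng, cols, rows)
    F = qr(A)
    Q = Matrix(F.Q)                          # thin Q, same size as A
    Q = Q * Diagonal(sign.(diag(F.R)))       # sign fix so Q is not biased by QR's convention
    M = tall ? Q : transpose(Q)              # wide request: orthonormal rows instead
    return Float32.(gain .* M)
end

W  = orthogonal_sketch(MersenneTwister(0), 5, 7)   # wide: W * W' ≈ I
W2 = orthogonal_sketch(MersenneTwister(0), 7, 5)   # tall: W2' * W2 ≈ I
```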
Examples
julia> using LinearAlgebra
julia> W = Flux.orthogonal(5, 7);
julia> summary(W)
"5×7 Matrix{Float32}"
julia> W * W' ≈ I(5)
true
julia> W2 = Flux.orthogonal(7, 5);
julia> W2 * W2' ≈ I(7)
false
julia> W2' * W2 ≈ I(5)
true
julia> W3 = Flux.orthogonal(3, 3, 2, 4);
julia> transpose(reshape(W3, :, 4)) * reshape(W3, :, 4) ≈ I(4)
true
References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
Flux.sparse_init — Function
sparse_init([rng], rows, cols; sparsity, std = 0.01) -> Array
sparse_init([rng]; kw...) -> Function
Return a Matrix{Float32} of size rows, cols where each column contains a fixed fraction of zero elements given by sparsity. Non-zero elements are normally distributed with a mean of zero and standard deviation std.
This method is described in [1].
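A minimal sketch of the column-wise rule, using a hypothetical helper. The assumption here is that each column gets ceil(sparsity * rows) zeros at randomly chosen rows; that reproduces the counts in the examples below (2 per column for sparsity 1/5 with 10 rows, 9 for sparsity 0.9).

```julia
using Random

# Hypothetical sparse-initialisation sketch: normal entries with standard
# deviation `std`, then a fixed number of zeros per column at random rows.
# Assumption: the per-column zero count is ceil(sparsity * rows).
function sparse_init_sketch(rng::AbstractRNG, rows::Integer, cols::Integer;
                            sparsity::Real, std::Real = 0.01)
    W = randn(rng, Float32, rows, cols) .* Float32(std)
    nzeros = ceil(Int, sparsity * rows)
    for j in 1:cols
        zero_rows = randperm(rng, rows)[1:nzeros]   # distinct rows, chosen per column
        W[zero_rows, j] .= 0
    end
    return W
end

W = sparse_init_sketch(MersenneTwister(0), 10, 11; sparsity = 0.9)
zeros_per_column = vec(count(iszero, W; dims = 1))   # 9 in every column
```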
Examples
julia> count(iszero, Flux.sparse_init(10, 10, sparsity=1/5))
20
julia> sum(0 .== Flux.sparse_init(10, 11, sparsity=0.9), dims=1)
1×11 Matrix{Int64}:
9 9 9 9 9 9 9 9 9 9 9
julia> Dense(3 => 10, tanh; init=Flux.sparse_init(sparsity=0.5))
Dense(3 => 10, tanh) # 40 parameters
julia> count(iszero, ans.weight, dims=1)
1×3 Matrix{Int64}:
5 5 5
References
[1] Martens, J, "Deep learning via Hessian-free optimization" Proceedings of the 27th International Conference on International Conference on Machine Learning. 2010.
Flux.identity_init — Function
identity_init(size...; gain=1, shift=0) -> Array
identity_init(; kw...) -> Function
Return an Array{Float32} of the given size which yields an identity mapping when used as parameters in most Flux layers. Use gain to scale the identity by a constant.
Often useful in the context of transfer learning, i.e. when one wants to add more capacity to a model but start from the same mapping.
Has the following behaviour:
- 1D: A Vector of zeros (useful for an identity bias)
- 2D: An identity matrix (useful for an identity matrix multiplication)
- More than 2D: A dense block array of center tap spatial filters (useful for an identity convolution)
Some caveats:
- Not all layers will be an identity mapping when used with this init. Exceptions include recurrent layers and normalization layers.
- Layers must have input_size == output_size for identity mapping to be possible. When this is not the case, extra dimensions of the array are padded with zeros.
- For convolutional layers, in addition to the above, the kernel sizes must also be odd and padding must be applied so that output feature maps have the same size as input feature maps, e.g. by using SamePad.
Use keyword shift (integer or tuple) to apply a circular shift to the output, equivalent to Base.circshift(identity_init(size...), shift).
For consistency with other initialisers, it accepts rng::AbstractRNG as an optional first argument, but this is ignored, since the result is not random.
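The three cases listed above can be sketched in a few lines. The helper identity_sketch is hypothetical, omits the shift keyword, and assumes the layout (spatial..., in_channels, out_channels) for arrays with more than two dimensions; with that assumption it matches the 3×3×2 example below.

```julia
# Hypothetical sketch of the identity rules above (not Flux's code; no `shift`).
function identity_sketch(dims::Integer...; gain::Real = 1)
    W = zeros(Float32, dims...)
    if length(dims) == 1
        return W                                 # 1D: zero vector, e.g. a bias
    elseif length(dims) == 2
        for i in 1:min(dims...)                  # 2D: identity, zero-padded if not square
            W[i, i] = gain
        end
    else
        # >2D, assumed layout (spatial..., in, out): put `gain` at the centre tap
        centre = map(d -> d ÷ 2 + 1, dims[1:end-2])
        for i in 1:min(dims[end-1], dims[end])
            W[centre..., i, i] = gain
        end
    end
    return W
end

identity_sketch(3, 5)                  # 3×5 with ones on the diagonal
identity_sketch(3, 3, 2; gain = 100)   # 100 at [2, 1, 1] and [2, 2, 2]
```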
Examples
julia> Flux.identity_init(3,5)
3×5 Matrix{Float32}:
1.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0
0.0 0.0 1.0 0.0 0.0
julia> Dense(5 => 3, relu, init=Flux.identity_init)([1,-2,3,-4,5])
3-element Vector{Float32}:
1.0
0.0
3.0
julia> Flux.identity_init(3,3,2; gain=100)
3×3×2 Array{Float32, 3}:
[:, :, 1] =
0.0 0.0 0.0
100.0 0.0 0.0
0.0 0.0 0.0
[:, :, 2] =
0.0 0.0 0.0
0.0 100.0 0.0
0.0 0.0 0.0
julia> x4 = cat([1 2 3; 4 5 6; 7 8 9]; dims=4);
julia> Conv((2,2), 1 => 1, init=Flux.identity_init(gain=10), pad=SamePad())(x4)
3×3×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
10.0 20.0 30.0
40.0 50.0 60.0
70.0 80.0 90.0
Flux.ones32 — Function
ones32(size...) = ones(Float32, size...)
Return an Array{Float32} of the given size filled with 1s.
Flux.zeros32 — Function
zeros32(size...) = zeros(Float32, size...)
Return an Array{Float32} of the given size filled with 0s.
Flux.rand32 — Function
rand32([rng], size...)
Return an Array{Float32} of the given size, filled like rand. When the size is not provided, rand32(rng::AbstractRNG) returns a function.
Flux.randn32 — Function
randn32([rng], size...)
Return an Array{Float32} of the given size, filled like randn. When the size is not provided, randn32(rng::AbstractRNG) returns a function.
Flux.create_bias — Function
create_bias(weights, bias, size...)
Return a bias parameter for a layer, based on the value given to the constructor's keyword bias=bias.
- bias == true creates a trainable array of the given size, of the same type as weights, initialised to zero.
- bias == false returns false, which is understood by AD to be non-differentiable.
- bias::AbstractArray uses the array given, provided it has the correct size. It will also correct the eltype to match that of weights.
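The three cases map naturally onto dispatch on the type of bias. The following is a sketch with a hypothetical name, create_bias_sketch; Flux's own function may differ in details such as its error message.

```julia
# Hypothetical sketch of the three bias cases (not Flux's implementation).
function create_bias_sketch(weights::AbstractArray, bias::Bool, dims::Integer...)
    # true: zero-filled array of the same array type and eltype as the weights
    # false: return `false`, meaning "no bias"
    return bias ? fill!(similar(weights, dims...), 0) : false
end

function create_bias_sketch(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
    size(bias) == dims || throw(DimensionMismatch("bias has size $(size(bias)), expected $dims"))
    return convert(AbstractArray{eltype(weights)}, bias)   # match the weights' eltype
end

W  = rand(Float32, 3, 4)
b1 = create_bias_sketch(W, true, 3)              # Float32 zeros
b2 = create_bias_sketch(W, false, 3)             # false
b3 = create_bias_sketch(W, [1.0, 2.0, 3.0], 3)   # converted to Float32
```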
These functions call:
Flux.rng_from_array — Function
rng_from_array(x)
Create an instance of the RNG most appropriate for x. As an example, if x is a CuArray, it will return a CUDA.default_rng(). If x is an Array instead, it will return a Random.default_rng().
Flux.nfan — Function
nfan(n_out, n_in=1) -> Tuple
nfan(dims...)
nfan(dims::Tuple)
For a layer characterized by dimensions dims, return a tuple (fan_in, fan_out), where fan_in is the number of input neurons connected to an output one, and fan_out is the number of output neurons connected to an input one.
This function is mainly used by weight initializers, e.g., kaiming_normal.
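The rules implied by the signature and the examples below can be sketched as follows. nfan_sketch is a hypothetical name; the convolution case assumes the weight layout (k1, ..., kN, in_channels, out_channels), under which every connection spans the whole kernel window, giving (18, 90) for a 3×3 kernel with 2 input and 10 output channels.

```julia
# Hypothetical sketch of the (fan_in, fan_out) rules (not Flux's source).
nfan_sketch(n::Integer) = (1, n)                              # vector: n_in defaults to 1
nfan_sketch(n_out::Integer, n_in::Integer) = (n_in, n_out)    # matrix stored as (out, in)
function nfan_sketch(dims::Integer...)
    # Assumed conv layout (k1, ..., kN, in_channels, out_channels)
    window = prod(dims[1:end-2])
    return (window * dims[end-1], window * dims[end])
end
nfan_sketch(dims::Tuple) = nfan_sketch(dims...)

nfan_sketch((20, 10))        # Dense(10 => 20) weight
nfan_sketch((3, 3, 2, 10))   # Conv((3, 3), 2 => 10) weight
```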
Examples
julia> layer = Dense(10, 20);
julia> Flux.nfan(size(layer.weight))
(10, 20)
julia> layer = Conv((3, 3), 2=>10);
julia> Flux.nfan(size(layer.weight))
(18, 90)
Changing the type of all parameters
The default eltype for models is Float32 since models are often trained/run on GPUs. The eltype of model m can be changed to Float64 by f64(m):
Flux.f64 — Function
f64(m)
Converts the eltype of a model's floating point parameters to Float64. Recurses into structs marked with @layer.
Flux.f32 — Function
f32(m)
Converts the eltype of a model's floating point parameters to Float32 (which is Flux's default). Recurses into structs marked with @layer.
Flux.f16 — Function
f16(m)
Converts the eltype of a model's floating point parameters to Float16. Recurses into structs marked with @layer.
Support for Float16 is limited on many CPUs. Julia may convert to Float32 for each operation, which is slow.
Example
julia> m = Chain(Dense(784, 2048, relu), Dense(2048, 10)) # all Float32
Chain(
Dense(784 => 2048, relu), # 1_607_680 parameters
Dense(2048 => 10), # 20_490 parameters
) # Total: 4 arrays, 1_628_170 parameters, 6.211 MiB.
julia> m |> f16 # takes half the memory
Chain(
Dense(784 => 2048, relu), # 1_607_680 parameters
Dense(2048 => 10), # 20_490 parameters
) # Total: 4 arrays, 1_628_170 parameters, 3.106 MiB.
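The memory figures above can be checked with simple arithmetic: parameter count times bytes per element, in MiB (2^20 bytes). This sketch counts raw parameter storage only. It gives about 3.105 MiB for Float16, a hair below the printed 3.106, presumably because the printed total also includes a small amount of per-array overhead.

```julia
# Back-of-the-envelope check: parameter count × bytes per element, in MiB.
dense_params(n_in, n_out) = n_in * n_out + n_out   # weight matrix + bias vector

total = dense_params(784, 2048) + dense_params(2048, 10)   # 1_628_170
mib(nbytes) = nbytes / 2^20

mib_f32 = mib(total * sizeof(Float32))   # ≈ 6.211
mib_f16 = mib(total * sizeof(Float16))   # ≈ 3.105, exactly half of the Float32 figure
```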