Flux.truncated_normal (function defined in module Flux)
truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
truncated_normal([rng]; kw...) -> Function
Return an Array{Float32} of the given size where each element is drawn from a truncated normal distribution. The numbers are distributed like filter(x -> lo<=x<=hi, mean .+ std .* randn(100)).
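The filter expression above describes the distribution by rejection: draw normal samples and discard those outside [lo, hi]. The sketch below implements that description directly, using only the Random and Statistics standard libraries. truncated_normal_rejection is a hypothetical helper for illustration, not part of Flux, and rejection sampling is a different technique from the one Flux uses to generate its values.

```julia
using Random, Statistics

# Hypothetical helper (not Flux API): rejection sampler matching
# filter(x -> lo <= x <= hi, mean .+ std .* randn(n)).
# It keeps drawing normal samples until n of them fall inside [lo, hi].
function truncated_normal_rejection(rng::AbstractRNG, n::Integer;
                                    mean = 0, std = 1, lo = -2, hi = 2)
    out = Float32[]
    while length(out) < n
        x = mean + std * randn(rng)
        if lo <= x <= hi          # discard draws outside the bounds
            push!(out, Float32(x))
        end
    end
    return out
end

rng = MersenneTwister(0)
xs = truncated_normal_rejection(rng, 10_000)
println(extrema(xs))   # both ends lie within [-2, 2]
println(std(xs))       # roughly 0.88: below 1 because the tails were cut off
```

For a standard normal truncated to [-2, 2] the theoretical standard deviation is about 0.88, which is why the last example in this entry widens the bounds to [-100, 100] before checking that std is close to 1.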
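The values are generated by drawing uniform samples on [0, 1) with rand() and mapping them through the inverse CDF of the truncated normal distribution. This method works best when lo ≤ mean ≤ hi; when mean lies far outside the interval, both bounds map to CDF values very close to 0 or 1 and the result loses accuracy.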
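A minimal sketch of inverse-CDF sampling for a truncated normal, assuming the SpecialFunctions.jl package for erf and erfinv. The names normcdf and truncated_normal_invcdf are hypothetical and this is not presented as Flux's exact code: the bounds are mapped to CDF values l and u, each uniform draw is rescaled into [l, u), and the inverse standard normal CDF, sqrt(2) * erfinv(2p - 1), maps it back to a value between lo and hi.

```julia
using Random, Statistics
using SpecialFunctions: erf, erfinv   # assumed dependency (SpecialFunctions.jl)

# Standard normal CDF written in terms of the error function.
normcdf(z) = (1 + erf(z / sqrt(2))) / 2

# Hypothetical sketch (not Flux API): inverse-CDF sampling of a
# normal(mean, std) distribution truncated to [lo, hi].
function truncated_normal_invcdf(rng::AbstractRNG, n::Integer;
                                 mean = 0.0, std = 1.0, lo = -2.0, hi = 2.0)
    l = normcdf((lo - mean) / std)    # CDF value at the lower bound
    u = normcdf((hi - mean) / std)    # CDF value at the upper bound
    xs = Vector{Float32}(undef, n)
    for i in 1:n
        p = l + (u - l) * rand(rng)   # uniform on [l, u)
        z = sqrt(2) * erfinv(2p - 1)  # inverse standard normal CDF
        # clamp guards against rounding just outside the bounds
        xs[i] = clamp(mean + std * z, lo, hi)
    end
    return xs
end

xs = truncated_normal_invcdf(MersenneTwister(1), 10_000)
println(extrema(xs), " ", mean(xs), " ", std(xs))
```

Unlike rejection sampling, every uniform draw produces exactly one output value, so the cost does not grow when [lo, hi] covers only a small part of the distribution's mass.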
julia> using Statistics
julia> Flux.truncated_normal(3, 4) |> summary
"3×4 Matrix{Float32}"
julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3)
(-2.0f0, 2.0f0)
julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
1.0f0
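The examples above use only the first method with the default rng. The sketch below also passes an explicit rng and uses the second method, truncated_normal([rng]; kw...), which returns a function instead of an array; the keywords given to it are applied when that function is later called with a size, so it can be passed as a layer's init keyword. It assumes Flux 0.13 or later for the Dense(in => out; init = ...) form. Note that lo and hi bound the values themselves, not multiples of std.

```julia
using Flux, Random

# First method with an explicit RNG: the same seed gives the same array.
a = Flux.truncated_normal(MersenneTwister(1), 2, 3)
b = Flux.truncated_normal(MersenneTwister(1), 2, 3)

# Second method: no size given, so a function is returned.
# lo and hi are absolute bounds on the values, not multiples of std.
init = Flux.truncated_normal(std = 0.1, lo = -0.25, hi = 0.25)
W = init(4, 3)                      # 4×3 Matrix{Float32}

# Using the returned function as a layer's weight initializer.
layer = Dense(3 => 4; init = init)
```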