Flux.truncated_normal

function defined in module Flux


truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
truncated_normal([rng]; kw...) -> Function

Return an Array{Float32} of the given size where each element is drawn from a truncated normal distribution, that is, a normal distribution with the given mean and std restricted to the interval [lo, hi]. The numbers are distributed like filter(x -> lo<=x<=hi, mean .+ std .* randn(100)).

The values are generated by sampling from Uniform(0, 1) with rand() and then applying the inverse CDF of the truncated normal distribution. This method works best when lo ≤ mean ≤ hi.
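The inverse-CDF approach described above can be sketched as follows. This is a minimal illustration under stated assumptions, not necessarily how Flux implements it: the helper names `norm_cdf` and `sample_truncated_normal` are hypothetical, and the sketch assumes the SpecialFunctions package is installed to provide `erf` and `erfinv`.

```julia
using Random
using SpecialFunctions: erf, erfinv  # assumed dependency, not part of Julia's standard library

# Standard normal CDF, written in terms of the error function.
norm_cdf(x) = (1 + erf(x / sqrt(2))) / 2

# Hypothetical helper: draw one truncated-normal sample by inverse-CDF sampling.
function sample_truncated_normal(rng::AbstractRNG; mean = 0.0, std = 1.0, lo = -2.0, hi = 2.0)
    l = norm_cdf((lo - mean) / std)       # probability mass below the lower bound
    u = norm_cdf((hi - mean) / std)       # probability mass below the upper bound
    p = l + rand(rng) * (u - l)           # Uniform(0, 1) draw mapped into [l, u)
    z = sqrt(2) * erfinv(2p - 1)          # inverse of the standard normal CDF
    return clamp(mean + std * z, lo, hi)  # guard against rounding past the bounds
end

rng = MersenneTwister(1)
xs = [sample_truncated_normal(rng) for _ in 1:10_000]
```

Because the uniform draw is rescaled into the interval between the two CDF values, every sample lands inside [lo, hi] without any rejection step. When mean lies far outside [lo, hi], the two CDF values become nearly equal and the inversion loses precision, which is consistent with the note that the method works best when lo ≤ mean ≤ hi.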

Examples


julia> using Statistics

julia> Flux.truncated_normal(3, 4) |> summary
"3×4 Matrix{Float32}"

julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3)
(-2.0f0, 2.0f0)

julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
1.0f0

Methods

There are 3 methods for Flux.truncated_normal: