Flux.modules (function, defined in module Flux)
modules(m)
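Return an iterator over the non-leaf objects that can be reached by recursing into m over the children given by functor. As the example below shows, the result is in practice a Vector{Any}, listed in depth-first order starting from m itself.
This is useful for applying a function (e.g. a regularizer) over specific modules, or over subsets of the parameters (e.g. the weights but not the biases).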
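Below is a minimal, self-contained sketch of the kind of traversal described above. It is written in plain Julia so it runs without Flux, and everything in it is hypothetical: MyDense and MyChain stand in for real layers, and children is a hand-written substitute for the children that functor would report. The sketch walks the tree depth-first from the root, records each object once by identity, and keeps only the non-leaf ones. This is our reading of the documented behavior, not Flux's actual implementation.

```julia
# Hypothetical stand-ins for layers (NOT Flux types).
struct MyDense
    weight::Matrix{Float32}
    bias::Vector{Float32}
end

struct MyChain
    layers::Tuple
end

# Hypothetical substitute for the children reported by `functor`:
# the objects to recurse into. Anything without children is a leaf
# (arrays included, so their elements are never visited).
children(x) = ()
children(d::MyDense) = (d.weight, d.bias)
children(c::MyChain) = (c.layers,)
children(t::Tuple) = t

isleaf(x) = isempty(children(x))

# Pre-order depth-first walk: visit each object once (keyed by identity),
# record it if it is not a leaf, then recurse into its children.
function collect_nonleaves!(out, seen, x)
    haskey(seen, x) && return out
    seen[x] = nothing
    isleaf(x) || push!(out, x)
    for c in children(x)
        collect_nonleaves!(out, seen, c)
    end
    return out
end

my_modules(root) = collect_nonleaves!(Any[], IdDict{Any,Nothing}(), root)

# Usage: a nested chain, similar in shape to Chain(Chain(...), Dense(...)).
d1 = MyDense(fill(1f0, 4, 3), ones(Float32, 4))
d2 = MyDense(fill(2f0, 2, 4), ones(Float32, 2))
inner = MyChain((d1,))
model = MyChain((inner, d2))

mods = my_modules(model)
# mods holds: model, (inner, d2), inner, (d1,), d1, d2
# (weight and bias arrays are leaves, so they are not included)

# Weights-only L2 penalty: filtering by type means biases are never summed.
l2_weights(m) = sum(sum(abs2, l.weight) for l in my_modules(m) if l isa MyDense)
l2_weights(model)  # returns 44.0f0, i.e. 12 * 1^2 + 8 * 2^2
```

As in the real function, the wrapper tuples show up alongside the layers because they are non-leaf nodes; filtering by type (here l isa MyDense) is what narrows the result to the modules a regularizer should touch.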
julia> using Flux

julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu));

julia> m2 = Chain(m1, Dense(64, 10))
Chain(
  Chain(
    Dense(784 => 64),  # 50_240 parameters
    BatchNorm(64, relu),  # 128 parameters, plus 128
  ),
  Dense(64 => 10),  # 650 parameters
)  # Total: 6 trainable arrays, 51_018 parameters,
   # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.

julia> Flux.modules(m2)
7-element Vector{Any}:
 Chain(Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))  # 51_018 parameters, plus 128 non-trainable
 (Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))
 Chain(Dense(784 => 64), BatchNorm(64, relu))  # 50_368 parameters, plus 128 non-trainable
 (Dense(784 => 64), BatchNorm(64, relu))
 Dense(784 => 64)  # 50_240 parameters
 BatchNorm(64, relu)  # 128 parameters, plus 128 non-trainable
 Dense(64 => 10)  # 650 parameters

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)

julia> L2(m2) isa Float32
true