# Advanced Model Building and Customisation
Here we describe some of the more advanced features that Flux provides, giving more control over model building.
## Customising Parameter Collection for a Model
Let's take our example `Affine` layer from the basics as a reference.
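As a reminder, its definition looks roughly like this (a minimal sketch following the basics section):

```julia
struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) = Affine(randn(out, in), randn(out))

# Overload call, so the object can be used as a function
(m::Affine)(x) = m.W * x .+ m.b
```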
By default all the fields in the `Affine` type are collected as its parameters. However, in some cases it may be desirable to hold other metadata in our "layers" that is not needed for training, and should therefore be ignored while the parameters are collected. With Flux, it is possible to mark the fields of our layers as trainable in two ways.
The first way of achieving this is through overloading the `trainable` function.
```julia-repl
julia> @functor Affine

julia> a = Affine(rand(3,3), rand(3))
Affine{Array{Float64,2},Array{Float64,1}}([0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955])

julia> Flux.params(a) # default behavior
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955]])

julia> Flux.trainable(a::Affine) = (a.W,)

julia> Flux.params(a)
Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297]])
```
Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`.
Another way of achieving this is through the `@functor` macro directly. Here, we can mark the fields we are interested in by grouping them in the second argument:

```julia
Flux.@functor Affine (W,)
```
However, doing this requires the `struct` to have a corresponding constructor that accepts those parameters.
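For instance, a one-argument constructor along the following lines would satisfy this; the zero-initialised bias is an assumption for illustration, not part of the original layer:

```julia
# Hypothetical constructor accepting only the fields listed in `@functor`,
# so the layer can be reconstructed from them; the zero bias is an assumption.
Affine(W) = Affine(W, zeros(size(W, 1)))
```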
## Freezing Layer Parameters
When it is desired to not include all the model parameters (e.g. in transfer learning), we can simply not pass those layers into our call to `params`.
Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain this using the slicing features `Chain` provides:

```julia
m = Chain(
      Dense(784, 64, relu),
      Dense(64, 64, relu),
      Dense(64, 10)   # input size must match the 64 outputs of the previous layer
    )

ps = Flux.params(m[3:end])
```
The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it. During training, gradients will only be computed for (and applied to) the last `Dense` layer, so only its parameters will change.
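To make this concrete, here is a minimal training sketch assuming some dummy data and the implicit-parameter API; the `x`, `y`, `loss`, and `opt` names are illustrative and not part of the original example:

```julia
using Flux

# Dummy batch of 32 MNIST-sized inputs with one-hot labels (assumed purely
# for illustration).
x = rand(Float32, 784, 32)
y = Flux.onehotbatch(rand(0:9, 32), 0:9)

loss(x, y) = Flux.Losses.logitcrossentropy(m(x), y)
opt = Descent(0.1)

# Gradients are computed (and applied) only for the parameters in `ps`,
# i.e. those of the last `Dense` layer.
gs = gradient(() -> loss(x, y), ps)
Flux.Optimise.update!(opt, ps, gs)
```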
`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogeneous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this:

```julia
Flux.params(m[1], m[3:end])
```
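This collects the parameters of the first and third `Dense` layers, leaving the second layer's weights and bias untouched during training.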
Sometimes, finer-grained control is needed. We can freeze a specific parameter of a specific layer that has already entered a `Params` object `ps`, by simply deleting it from `ps`:

```julia
ps = Flux.params(m)
delete!(ps, m[2].b)
```
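As a quick sanity check (a sketch reusing the dummy `loss`, `x`, and `y` from the training example above), the deleted parameter no longer receives a gradient, while the other parameters of the same layer are still trained:

```julia
gs = gradient(() -> loss(x, y), ps)

haskey(gs, m[2].b)  # false: no gradient is tracked for the deleted bias
haskey(gs, m[2].W)  # true: the weights of the same layer still train
```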