Shape Inference

To help you generate models in an automated fashion, Flux.outputsize lets you calculate the size of the output produced by a layer or model for a given input size. This is especially useful for layers like Conv, where the output size depends on the kernel size, stride, and padding.

It works by passing a "dummy" array into the model that preserves size information without running any computation. outputsize(f, inputsize) works for all layers (including custom layers) out of the box. By default, inputsize is expected to include the batch dimension. To leave it out, call outputsize(f, inputsize; padbatch=true), which appends a batch dimension of size one.
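A minimal sketch of both calling conventions, plus a custom layer. The 32x32 RGB input size and the `EveryOtherRow` type are hypothetical, chosen only for illustration:

```julia
using Flux

layer = Conv((3, 3), 3 => 16)   # 3x3 kernel, stride 1, no padding

# Input size including the batch dimension (width, height, channels, batch)
with_batch = Flux.outputsize(layer, (32, 32, 3, 8))            # (30, 30, 16, 8)

# Same query without a batch dimension; padbatch=true appends one of size 1
padded = Flux.outputsize(layer, (32, 32, 3); padbatch=true)    # (30, 30, 16, 1)

# Hypothetical custom layer (not part of Flux): keeps every other row
struct EveryOtherRow end
(::EveryOtherRow)(x::AbstractMatrix) = x[1:2:end, :]

custom = Flux.outputsize(EveryOtherRow(), (10, 4))             # (5, 4)
```

Each unpadded 3x3 convolution shrinks width and height by 2, which is why 32 becomes 30.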

Using this utility function lets you automate model building for various inputs like so:

"""
    make_model(width, height, inchannels, nclasses;
               layer_config = [16, 16, 32, 32, 64, 64])

Create a CNN for a given set of configuration parameters.

# Arguments
- `width`: the input image width
- `height`: the input image height
- `inchannels`: the number of channels in the input image
- `nclasses`: the number of output classes
- `layer_config`: a vector giving the number of filters in each conv layer
"""
function make_model(width, height, inchannels, nclasses;
                    layer_config = [16, 16, 32, 32, 64, 64])
  # construct a vector of conv layers programmatically
  conv_layers = [Conv((3, 3), inchannels => layer_config[1])]
  for (infilters, outfilters) in zip(layer_config, layer_config[2:end])
    push!(conv_layers, Conv((3, 3), infilters => outfilters))
  end

  # compute the output dimensions for the conv layers
  # wrap the vector in a Chain so that it is callable, and
  # use padbatch=true to set the batch dimension to 1
  conv_outsize = Flux.outputsize(Chain(conv_layers...), (width, height, inchannels); padbatch=true)

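  # the input dimension to Dense is programmatically calculated from
  #  width, height, and inchannels; the conv output is flattened first,
  #  since Dense expects the features along the first dimension
  return Chain(conv_layers..., Flux.flatten, Dense(prod(conv_outsize) => nclasses))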
end
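
The size arithmetic behind this pattern can be checked end to end. The following is a self-contained sketch, not the function above: it assumes a hypothetical 28x28 single-channel input with 10 classes, rebuilds the same conv stack as a Chain, and inserts Flux.flatten before the Dense layer so the 4-dimensional conv output becomes a matrix.

```julia
using Flux

# Hypothetical example values: 28x28 grayscale input, 10 output classes
width, height, inchannels, nclasses = 28, 28, 1, 10
layer_config = [16, 16, 32, 32, 64, 64]

# Channel counts seen by consecutive conv layers: 1 => 16 => 16 => 32 ...
channels = [inchannels; layer_config]
convs = Chain([Conv((3, 3), cin => cout)
               for (cin, cout) in zip(channels, channels[2:end])]...)

# Six unpadded 3x3 convs each remove 2 from width and height: 28 - 6*2 = 16
conv_outsize = Flux.outputsize(convs, (width, height, inchannels); padbatch=true)

# Flatten the conv output so Dense receives a (features, batch) matrix
model = Chain(convs, Flux.flatten, Dense(prod(conv_outsize) => nclasses))

# Forward pass on a random batch of 4 images
x = rand(Float32, width, height, inchannels, 4)
outsize = size(model(x))   # (10, 4)
```

Here conv_outsize is (16, 16, 64, 1), so the Dense layer takes 16 * 16 * 64 = 16384 input features.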

Flux.outputsize (Function)
outputsize(m, x_size, y_size, ...; padbatch=false)

For model or layer m accepting multiple arrays as input, this returns size(m((x, y, ...))) given x_size = size(x), y_size = size(y), etc.

Examples

julia> x, y = rand(Float32, 5, 64), rand(Float32, 7, 64);

julia> par = Parallel(vcat, Dense(5 => 9), Dense(7 => 11));

julia> Flux.outputsize(par, (5, 64), (7, 64))
(20, 64)

julia> m = Chain(par, Dense(20 => 13), softmax);

julia> Flux.outputsize(m, (5,), (7,); padbatch=true)
(13, 1)

julia> par(x, y) == par((x, y)) == Chain(par, identity)((x, y))
true

Notice that Chain only accepts multiple arrays as a tuple, while Parallel also accepts them as multiple arguments; outputsize always supplies the tuple.
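
As a quick sanity check, the inferred size can be compared with an actual forward pass. A minimal sketch reusing the layer sizes from the examples above; it assumes the installed Flux version supports the `in => out` form of the Dense constructor:

```julia
using Flux

par = Parallel(vcat, Dense(5 => 9), Dense(7 => 11))
m = Chain(par, Dense(20 => 13), softmax)

x, y = rand(Float32, 5, 64), rand(Float32, 7, 64)

# Inferred from the input sizes alone, no real data involved
inferred = Flux.outputsize(m, (5, 64), (7, 64))

# Actual forward pass; Chain takes the two inputs as a single tuple
actual = size(m((x, y)))

sizes_match = inferred == actual   # both are (13, 64)
```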
