Linear(; σ=Flux.relu)

MLJFlux builder that constructs a fully connected two-layer network (input and output layers only, i.e. a single dense layer) with activation function σ. The numbers of input and output nodes are determined from the data. Weights are initialized using Flux.glorot_uniform(rng), where rng is inferred from the rng field of the MLJFlux model.
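As a rough illustration of what this builder produces, here is a minimal plain-Julia sketch that does not use Flux: one dense layer computing σ.(W * x .+ b), with an n_out × n_in weight matrix drawn from a Glorot-uniform distribution and zero biases. The names glorot_uniform_sketch and linear_sketch are hypothetical stand-ins, not MLJFlux or Flux functions.

```julia
using Random

# Hypothetical stand-in for Flux.glorot_uniform (not Flux's code): draws each
# entry from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)), i.e. gain = 1.
function glorot_uniform_sketch(rng, n_out, n_in)
    a = sqrt(6 / (n_in + n_out))
    return (2 .* rand(rng, n_out, n_in) .- 1) .* a
end

relu(x) = max(zero(x), x)

# Hypothetical analogue of what Linear builds: one dense layer mapping the
# n_in inputs straight to the n_out outputs, followed by the activation σ.
function linear_sketch(rng, n_in, n_out; σ=relu)
    W = glorot_uniform_sketch(rng, n_out, n_in)  # n_out × n_in weight matrix
    b = zeros(n_out)                             # biases start at zero
    return x -> σ.(W * x .+ b)
end

rng = MersenneTwister(123)
model = linear_sketch(rng, 4, 2)
y = model(ones(4))   # a length-2 vector of non-negative values
```

Flux works in Float32 by default, while this sketch uses Float64 for simplicity. With n_in = 4 and n_out = 2 the layer holds 4 × 2 weights plus 2 biases, 10 trainable parameters in total.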

Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid)

MLJFlux builder that constructs a fully connected three-layer network with n_hidden nodes in the hidden layer, followed by dropout with the specified probability dropout (default 0.5). The activation function σ is applied between the hidden and final layers. If n_hidden=0 (the default), the number of hidden nodes is set to the geometric mean of the numbers of input and output nodes, rounded to the nearest integer. The numbers of input and output nodes are determined from the data.

Each layer is initialized using Flux.glorot_uniform(rng), where rng is inferred from the rng field of the MLJFlux model.
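The hidden-width rule and the layer order can be sketched in plain Julia without Flux. This is a minimal illustration under stated assumptions: rounding to the nearest integer is assumed for the geometric mean, dropout is shown in its training-mode "inverted" form (survivors rescaled by 1/(1 - p)), and random normal weights replace Glorot initialization for brevity. All function names here are hypothetical.

```julia
using Random

sigmoid(x) = 1 / (1 + exp(-x))

# Width rule: n_hidden = 0 means "use the geometric mean of the input and
# output sizes", rounded to the nearest integer (rounding is an assumption).
short_hidden_width(n_in, n_out, n_hidden=0) =
    n_hidden == 0 ? round(Int, sqrt(n_in * n_out)) : n_hidden

# Inverted dropout as applied during training: zero each entry with
# probability p and rescale the survivors by 1 / (1 - p).
function dropout_sketch(rng, x, p)
    keep = rand(rng, length(x)) .> p
    return x .* keep ./ (1 - p)
end

# Hypothetical training-mode forward pass: dense layer + σ, dropout, then a
# dense output layer with no activation. Weights are random for brevity.
function short_forward(rng, x, n_out; n_hidden=0, dropout=0.5, σ=sigmoid)
    n_in = length(x)
    h = short_hidden_width(n_in, n_out, n_hidden)
    W1, b1 = randn(rng, h, n_in), zeros(h)
    W2, b2 = randn(rng, n_out, h), zeros(n_out)
    hidden = dropout_sketch(rng, σ.(W1 * x .+ b1), dropout)
    return W2 * hidden .+ b2
end

short_hidden_width(4, 9)   # sqrt(36) = 6 hidden nodes
y = short_forward(MersenneTwister(0), ones(4), 9)
```

At prediction time dropout is switched off, so the hidden activations pass through unchanged.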

MLP(; hidden=(100,), σ=Flux.relu)

MLJFlux builder that constructs a multilayer perceptron (MLP). The ith element of hidden is the number of neurons in the ith hidden layer. The activation function σ is applied after each hidden layer; the output layer has no activation. The numbers of input and output nodes are determined from the data.

Each layer is initialized using Flux.glorot_uniform(rng), where rng is inferred from the rng field of the MLJFlux model.
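To see how the hidden tuple translates into layers, the following plain-Julia sketch lists the (outputs, inputs) shape of each dense weight matrix, from the input layer to the output layer, and counts the trainable parameters. It only tracks shapes; mlp_layer_shapes and n_params are hypothetical helpers, not part of MLJFlux.

```julia
# Hypothetical helper: the (n_out, n_in) weight shapes of the dense layers an
# MLP with the given hidden widths would contain, input layer to output layer.
function mlp_layer_shapes(n_in, n_out, hidden)
    widths = (n_in, hidden..., n_out)
    return [(widths[i + 1], widths[i]) for i in 1:length(widths) - 1]
end

# Trainable parameters: one weight matrix and one bias vector per layer.
n_params(shapes) = sum(o * i + o for (o, i) in shapes)

shapes = mlp_layer_shapes(4, 3, (64, 32))  # [(64, 4), (32, 64), (3, 32)]
n_params(shapes)                            # 320 + 2080 + 99 = 2499
```

With the default hidden=(100,), 4 inputs and 3 outputs give the shapes [(100, 4), (3, 100)] and 803 parameters.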

@builder neural_net

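Creates a builder from the expression neural_net. Inside that expression the variables rng, n_in, n_out and n_channels may be used; when the network is built they are bound to the random number generator, the input size, the output size and the number of input channels (relevant for image data), so the same builder works for any data dimensions.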
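The idea of deferring construction until the data sizes are known can be sketched in plain Julia: a macro wraps the user's expression in an anonymous function whose arguments are named rng, n_in, n_out and n_channels. This is a hypothetical illustration of the mechanism only; GenericBuilderSketch, build_sketch and @builder_sketch are invented names, and MLJFlux's actual implementation may differ.

```julia
# Hypothetical container: stores a function of (rng, n_in, n_out, n_channels)
# that produces the network once the data sizes are known.
struct GenericBuilderSketch{F}
    apply::F
end

# Deferred construction: sizes are supplied at build time; n_channels
# defaults to 1 for non-image data in this sketch.
build_sketch(b::GenericBuilderSketch, rng, n_in, n_out, n_channels=1) =
    b.apply(rng, n_in, n_out, n_channels)

# Wraps an expression in an anonymous function whose arguments are named
# rng, n_in, n_out and n_channels, so the expression can refer to them.
macro builder_sketch(ex)
    return esc(:(GenericBuilderSketch((rng, n_in, n_out, n_channels) -> $ex)))
end

# Here the "network" is just a list of layer shapes, to avoid depending on Flux.
b = @builder_sketch([(16, n_in), (n_out, 16)])
build_sketch(b, nothing, 4, 2)   # [(16, 4), (2, 16)]
```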


julia> using Flux, MLJFlux

julia> import MLJFlux: @builder;

julia> nn = NeuralNetworkRegressor(builder = @builder(Chain(Dense(n_in, 64, relu),
                                                            Dense(64, 32, relu),
                                                            Dense(32, n_out))));

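julia> conv_builder = @builder begin
           # convolutional front end, flattened to one feature vector per sample
           front = Chain(Conv((3, 3), n_channels => 16), Flux.flatten)
           # length of that feature vector for a single image of size n_in
           d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
           Chain(front, Dense(d, n_out))
       end;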

julia> conv_nn = NeuralNetworkRegressor(builder = conv_builder);
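In conv_builder, Flux.outputsize returns the size of the flattened output for a batch of one image, (d, 1), and first extracts d. The arithmetic behind d can be reproduced by hand, assuming a convolution with stride 1 and no padding (Flux's Conv defaults). The helpers below are hypothetical and only mirror that calculation.

```julia
# Hypothetical helper: spatial size after a convolution with stride 1 and no
# padding; each dimension shrinks by (kernel size - 1).
conv_out_dims(n_in::Tuple, k::Tuple) = n_in .- k .+ 1

# Length of the flattened feature vector that the Dense layer receives, i.e.
# the value d in conv_builder. The number of input channels does not affect
# it: the Conv layer always produces out_channels feature maps.
flattened_length(n_in::Tuple, k::Tuple, out_channels) =
    prod(conv_out_dims(n_in, k)) * out_channels

flattened_length((28, 28), (3, 3), 16)   # 26 * 26 * 16 = 10816
```

So for 28×28 images the final layer would be Dense(10816, n_out), whatever the value of n_channels.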