DataLoader

Flux provides the DataLoader type in the Flux.Data module to handle iteration over mini-batches of data.

Flux.Data.DataLoader (Type)
Flux.DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)

An object that iterates over mini-batches of data, each mini-batch containing batchsize observations (except possibly the last one).

Takes as input a single data tensor, or a tuple (or a named tuple) of tensors. The last dimension in each tensor is the observation dimension, i.e. the one divided into mini-batches.
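
As a minimal sketch of the observation-dimension rule above, the snippet below batches a 3-dimensional array together with a vector of labels. The names images and labels and the sizes used are illustrative assumptions, not part of the Flux API.

```julia
using Flux

# Hypothetical data: 100 observations stored along the LAST dimension.
images = rand(Float32, 28, 28, 100)   # 28×28 arrays, 100 of them
labels = rand(0:9, 100)               # one label per observation

# A tuple of arrays is batched jointly along each array's last dimension.
loader = Flux.DataLoader((images, labels), batchsize=32)

x, y = first(loader)
@assert size(x) == (28, 28, 32)   # leading dimensions are untouched
@assert size(y) == (32,)          # labels are sliced with the same indices
```

Only the last dimension is split; every other dimension passes through unchanged, so all arrays in the tuple must agree on the size of their last dimension.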

If shuffle=true, the observations are reshuffled each time iteration restarts, using the random number generator rng. If partial=false and the number of observations is not divisible by batchsize, the last (incomplete) mini-batch is dropped.
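
The effect of partial on the number of mini-batches can be sketched as follows. The variable names are illustrative, and the 64-observation array is an assumption chosen so that the batch size does not divide the data evenly.

```julia
using Flux

X = rand(Float32, 10, 64)   # 64 observations along the last dimension

keep_last = Flux.DataLoader(X, batchsize=30)                 # partial=true is the default
drop_last = Flux.DataLoader(X, batchsize=30, partial=false)

# 64 = 30 + 30 + 4, so the default loader yields a short final batch.
@assert length(keep_last) == 3
@assert size(last(collect(x for x in keep_last)), 2) == 4

# With partial=false the trailing 4 observations are skipped.
@assert length(drop_last) == 2
@assert all(size(x, 2) == 30 for x in drop_last)
```

Dropping the incomplete batch is useful when downstream code assumes a fixed batch size, at the cost of not visiting every observation in each pass.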

The original data is preserved in the data field of the DataLoader.

Examples

julia> Xtrain = rand(10, 100);

julia> array_loader = Flux.DataLoader(Xtrain, batchsize=2);

julia> for x in array_loader
         @assert size(x) == (10, 2)
         # do something with x, 50 times
       end

julia> array_loader.data === Xtrain
true

julia> tuple_loader = Flux.DataLoader((Xtrain,), batchsize=2);  # similar, but yielding 1-element tuples

julia> for x in tuple_loader
         @assert x isa Tuple{Matrix}
         @assert size(x[1]) == (10, 2)
       end

julia> Ytrain = rand('a':'z', 100);  # now make a DataLoader yielding 2-element named tuples

julia> train_loader = Flux.DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);

julia> for epoch in 1:100
         for (x, y) in train_loader  # access via tuple destructuring
           @assert size(x) == (10, 5)
           @assert size(y) == (5,)
           # loss += f(x, y) # etc, runs 100 * 20 times
         end
       end

julia> first(train_loader).label isa Vector{Char}  # access via property name
true

julia> first(train_loader).label == Ytrain[1:5]  # because of shuffle=true
false

julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize=30))  # partial=false would omit last
10×30 Matrix{Int8}
10×30 Matrix{Int8}
10×4 Matrix{Int8}