DataLoader

Flux provides the DataLoader type in the Flux.Data module to handle iteration over mini-batches of data.

Flux.Data.DataLoader (Type)
DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)

An object that iterates over mini-batches of data, each mini-batch containing batchsize observations (except possibly the last one).

Takes as input a single data tensor, or a tuple (or a named tuple) of tensors. The last dimension in each tensor is considered to be the observation dimension.
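To make the observation-dimension convention concrete, here is a minimal sketch in plain Julia. It is not Flux's implementation: the helper name `take_batch` is hypothetical and is introduced only to illustrate that a mini-batch is a slice along the last dimension, and that the same observation indices are applied to every tensor in a tuple or named tuple.

```julia
# Hypothetical helper (not part of Flux): select the observations `idx`
# along the last dimension of an array.
take_batch(x::AbstractArray, idx) = collect(selectdim(x, ndims(x), idx))

# For a tuple or named tuple, apply the same indices to every array.
take_batch(data::Union{Tuple,NamedTuple}, idx) = map(x -> take_batch(x, idx), data)

X = reshape(collect(1.0:1000.0), 10, 100)   # 10 features, 100 observations
Y = collect(1.0:100.0)                      # 100 labels, one per observation

xb = take_batch(X, 1:2)                         # single array
xy = take_batch((X, Y), 1:2)                    # tuple of arrays
nt = take_batch((images=X, labels=Y), 3:4)      # named tuple of arrays
```

Here `xb` has size `(10, 2)`, `xy` is a tuple holding a `(10, 2)` matrix and a length-2 vector, and `nt` keeps the field names `images` and `labels`.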

If shuffle=true, the observations are reshuffled each time iteration is restarted. If partial=false, the last mini-batch is dropped when it contains fewer than batchsize observations.
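The effect of the shuffle and partial options on how observations are grouped can be sketched with the standard library alone. The helper `batch_indices` below is hypothetical and only mirrors the behaviour described above; it is not how Flux implements DataLoader.

```julia
using Random

# Hypothetical helper (not part of Flux): compute the observation indices
# that make up each mini-batch for a dataset of `n` observations.
function batch_indices(n::Integer; batchsize::Integer=1, shuffle::Bool=false,
                       partial::Bool=true, rng::AbstractRNG=Random.default_rng())
    # Visit observations in order, or in a fresh random permutation.
    order = shuffle ? randperm(rng, n) : collect(1:n)
    batches = [collect(b) for b in Iterators.partition(order, batchsize)]
    # With partial=false, drop a trailing batch that is smaller than batchsize.
    if !partial && !isempty(batches) && length(last(batches)) < batchsize
        pop!(batches)
    end
    return batches
end

full     = batch_indices(100; batchsize=32)                 # 4 batches, last has 4 items
dropped  = batch_indices(100; batchsize=32, partial=false)  # 3 batches of 32
shuffled = batch_indices(100; batchsize=32, shuffle=true, rng=MersenneTwister(0))
```

With 100 observations and a batch size of 32, keeping the partial batch gives batches of 32, 32, 32, and 4 observations, while partial=false leaves only the three full batches. Calling the helper again with shuffle=true draws a new permutation, which corresponds to reshuffling on each new pass over the data.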

The original data is preserved in the data field of the DataLoader.

Usage example:

using Flux.Data: DataLoader

Xtrain = rand(10, 100)
train_loader = DataLoader(Xtrain, batchsize=2)
# iterate over 50 mini-batches of size 2
for x in train_loader
    @assert size(x) == (10, 2)
    # ... process the mini-batch x
end

train_loader.data   # original dataset

# similar, but yielding tuples
train_loader = DataLoader((Xtrain,), batchsize=2)
for (x,) in train_loader
    @assert size(x) == (10, 2)
    # ... process the mini-batch x
end

Xtrain = rand(10, 100)
Ytrain = rand(100)
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
for epoch in 1:100
    for (x, y) in train_loader
        @assert size(x) == (10, 2)
        @assert size(y) == (2,)
        # ... process the mini-batch (x, y)
    end
end

# train for 10 epochs (assumes `loss`, `ps`, and `opt` are already defined)
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)

# a NamedTuple can be used to name the tensors
train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
for datum in train_loader
    @assert size(datum.images) == (10, 2)
    @assert size(datum.labels) == (2,)
end