DataLoader
Flux provides the DataLoader type in the Flux.Data module to handle iteration over mini-batches of data.
Flux.Data.DataLoader — Type
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
An object that iterates over mini-batches of data, each mini-batch containing batchsize observations (except possibly the last one).
Takes as input one or more data tensors, e.g. X in unsupervised learning, or X and Y in supervised learning. The last dimension of each tensor is treated as the observation dimension.
If shuffle=true, shuffles the observations each time iteration is restarted. If partial=false, drops the last mini-batch if it is smaller than batchsize.
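The effect of partial and shuffle can be checked on a small array. The following is a minimal sketch that assumes the varargs constructor documented here and an import via Flux.Data; batch_sizes is a helper made up for this illustration and is not part of Flux.

```julia
using Flux.Data: DataLoader

# 10 observations, labelled 1..10 along the last (observation) dimension
X = reshape(collect(1:10), 1, 10)

# Hypothetical helper for this sketch: record the size of every mini-batch
function batch_sizes(loader)
    sizes = Int[]
    for x in loader
        push!(sizes, size(x, 2))
    end
    return sizes
end

# partial=true (the default) keeps the trailing, smaller mini-batch
sizes_partial = batch_sizes(DataLoader(X, batchsize=3))                  # [3, 3, 3, 1]

# partial=false drops it
sizes_full = batch_sizes(DataLoader(X, batchsize=3, partial=false))      # [3, 3, 3]

# shuffle=true changes the order, but one pass still visits every observation once
seen = Int[]
for x in DataLoader(X, batchsize=3, shuffle=true)
    append!(seen, vec(x))
end
```

Sorting seen should recover 1:10, since shuffling only permutes the observation indices.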
The original data is preserved as a tuple in the data field of the DataLoader.
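As a small illustration of the data field, the sketch below assumes the varargs constructor documented here; the names Xorig and Yorig are chosen only for this example.

```julia
using Flux.Data: DataLoader

X = rand(10, 100)
Y = rand(100)
loader = DataLoader(X, Y, batchsize=2)

# `data` holds the inputs as a tuple, in the order they were passed in
Xorig, Yorig = loader.data
```

Destructuring the tuple gives back the full, unbatched arrays, which can be convenient for evaluating a model on the whole dataset after training.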
Example usage:
using Flux.Data: DataLoader

Xtrain = rand(10, 100)
train_loader = DataLoader(Xtrain, batchsize=2)
# iterate over 50 mini-batches of size 2
for x in train_loader
    @assert size(x) == (10, 2)
    # ...
end

train_loader.data   # original dataset, stored as a tuple

Xtrain = rand(10, 100)
Ytrain = rand(100)
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
for epoch in 1:100
    for (x, y) in train_loader
        @assert size(x) == (10, 2)
        @assert size(y) == (2,)
        # ...
    end
end

# train for 10 epochs (assumes loss, ps and opt are already defined)
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
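
To see what ncycle hands to the training loop, the mini-batches can simply be counted. This is a sketch only, assuming the IterTools package is installed and the varargs DataLoader constructor documented here; no model, loss, or optimiser is involved.

```julia
using Flux.Data: DataLoader
using IterTools: ncycle

Xtrain = rand(10, 100)
train_loader = DataLoader(Xtrain, batchsize=2)   # 50 mini-batches per pass

# ncycle(iter, n) replays the whole iterator n times in sequence,
# so 10 cycles over 50 mini-batches yields 500 mini-batches in total
nbatches = count(_ -> true, ncycle(train_loader, 10))
```

Because iteration restarts for every cycle, a loader built with shuffle=true would be reshuffled at the start of each of the 10 passes.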