DataLoader
Flux provides the DataLoader type in the Flux.Data module to handle iteration over mini-batches of data.
Flux.Data.DataLoader - Type
Flux.DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
An object that iterates over mini-batches of data, each mini-batch containing batchsize observations (except possibly the last one).
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors. The last dimension in each tensor is the observation dimension, i.e. the one divided into mini-batches.
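To make the "observation dimension" concrete, here is a minimal sketch in plain Julia that slices a matrix and a label vector along their last dimension using the same index ranges. It does not use Flux and is not how DataLoader is implemented; the helper name getobs_last is hypothetical and exists only for illustration.

```julia
# Hypothetical helper (not part of Flux): copy the observations `idx`
# out of `x`, where observations live along the last dimension.
getobs_last(x::AbstractArray, idx) = collect(selectdim(x, ndims(x), idx))

X = rand(10, 100)        # 100 observations with 10 features each
Y = rand('a':'z', 100)   # 100 labels, one per observation

# Split the observation indices 1:100 into consecutive groups of 5 and
# slice every array with the same indices, so features and labels stay aligned.
batches = [(getobs_last(X, idx), getobs_last(Y, idx))
           for idx in Iterators.partition(1:100, 5)]
```

Each element of batches is a tuple of a 10×5 matrix and a 5-element vector, which is the shape of mini-batch that a loader over a tuple of these two arrays would be expected to yield with batchsize=5.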
If shuffle=true, the observations are reshuffled each time iteration restarts. If partial=false and the number of observations is not divisible by batchsize, the last mini-batch is dropped.
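The effect of shuffle and partial on the batch layout can be sketched with index bookkeeping alone. The function below is an illustrative assumption written in plain Julia, not Flux's actual implementation; batch_indices and its keyword defaults are hypothetical names chosen to mirror the options described above.

```julia
using Random

# Hypothetical helper (not Flux's implementation): return the observation
# indices that would make up each mini-batch for `n` observations.
function batch_indices(n, batchsize; shuffle=false, partial=true,
                       rng=Random.default_rng())
    # A fresh permutation per call models reshuffling on every restart.
    order = shuffle ? randperm(rng, n) : collect(1:n)
    parts = [order[i:min(i + batchsize - 1, n)] for i in 1:batchsize:n]
    # With partial=false, drop a trailing batch that is shorter than batchsize.
    if !partial && !isempty(parts) && length(last(parts)) < batchsize
        pop!(parts)
    end
    return parts
end

sizes_partial = length.(batch_indices(64, 30))                 # keeps the short last batch
sizes_full    = length.(batch_indices(64, 30; partial=false))  # drops it
shuffled      = batch_indices(100, 5; shuffle=true)
```

With 64 observations and a batch size of 30, the first call yields batches of 30, 30 and 4 indices, while partial=false leaves only the two full batches. Shuffling changes the order of the indices but every observation still appears exactly once per pass.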
The original data is preserved in the data field of the DataLoader.
Examples
julia> Xtrain = rand(10, 100);
julia> array_loader = Flux.DataLoader(Xtrain, batchsize=2);
julia> for x in array_loader
         @assert size(x) == (10, 2)
         # do something with x, 50 times
       end
julia> array_loader.data === Xtrain
true
julia> tuple_loader = Flux.DataLoader((Xtrain,), batchsize=2); # similar, but yielding 1-element tuples
julia> for x in tuple_loader
         @assert x isa Tuple{Matrix}
         @assert size(x[1]) == (10, 2)
       end
julia> Ytrain = rand('a':'z', 100); # now make a DataLoader yielding 2-element named tuples
julia> train_loader = Flux.DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);
julia> for epoch in 1:100
         for (x, y) in train_loader  # access via tuple destructuring
           @assert size(x) == (10, 5)
           @assert size(y) == (5,)
           # loss += f(x, y) # etc, runs 100 * 20 times
         end
       end
julia> first(train_loader).label isa Vector{Char} # access via property name
true
julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
false
julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize=30)) # partial=false would omit last
10×30 Matrix{Int8}
10×30 Matrix{Int8}
10×4 Matrix{Int8}