# batchview.jl
#
# `MLUtils/batchview.jl` is a source file in module `MLUtils`.

			
			
			
			
			
			"""
    BatchView(data, batchsize; partial=true, collate=nothing)
    BatchView(data; batchsize=1, partial=true, collate=nothing)

Create a view of the given `data` that represents it as a vector
of batches. Every batch contains `batchsize` observations, except
possibly the last one (see `partial` below). The batch size is
specified with the parameter `batchsize`.

If the number of observations in `data` is not divisible by
`batchsize`, the remaining observations are dropped when
`partial=false`. With `partial=true` (the default) they are kept,
and the last batch is smaller than the others.

If `batchsize` is larger than the number of observations, a warning
is emitted and the batch size is reduced to the number of observations.

Note that any data access is delayed until `getindex` is called.

If used as an iterator, the object will iterate over the dataset
once, effectively denoting an epoch.

For `BatchView` to work on some data structure, the type of the
given variable `data` must implement the data container
interface. See [`ObsView`](@ref) for more info.

# Arguments

- **`data`** : The object describing the dataset. Can be of any
    type as long as it implements [`getobs`](@ref) and
    [`numobs`](@ref) (see Details for more information).

- **`batchsize`** : The batch size of each batch.
    It is the number of observations that each batch must contain
    (except possibly for the last one).

- **`partial`** : If `partial=false` and the number of observations is
    not divisible by the batch size, then the last mini-batch is dropped.

- **`collate`**: Batching behavior. If `nothing` (default), a batch
    is `getobs(data, indices)`. If `false`, each batch is
    `[getobs(data, i) for i in indices]`. When `true`, applies [`batch`](@ref)
    to the vector of observations in a batch, recursively collating
    arrays in the last dimensions. See [`batch`](@ref) for more information
    and examples. Any other value throws an `ArgumentError`.

# Examples

```julia
using MLUtils
X, Y = MLUtils.load_iris()

A = BatchView(X, batchsize=30)
@assert typeof(A) <: BatchView <: AbstractVector
@assert eltype(A) <: SubArray{Float64,2}
@assert length(A) == 5 # Iris has 150 observations
@assert size(A[1]) == (4,30) # Iris has 4 features

# 5 batches of size 30 observations
for x in BatchView(X, batchsize=30)
    @assert typeof(x) <: SubArray{Float64,2}
    @assert numobs(x) === 30
end

# 7 batches of size 20 observations
# Note that the iris dataset has 150 observations,
# which means that with a batchsize of 20, the last
# 10 observations will be ignored
for (x, y) in BatchView((X, Y), batchsize=20, partial=false)
    @assert typeof(x) <: SubArray{Float64,2}
    @assert typeof(y) <: SubArray{String,1}
    @assert numobs(x) == numobs(y) == 20
end

# collate tuple observations
for (x, y) in BatchView((rand(10, 3), ["a", "b", "c"]), batchsize=2, collate=true, partial=false)
    @assert size(x) == (10, 2)
    @assert size(y) == (2,)
end


# randomly assign observations to one and only one batch.
for (x, y) in BatchView(shuffleobs((X, Y)), batchsize=20)
    @assert typeof(x) <: SubArray{Float64,2}
    @assert typeof(y) <: SubArray{String,1}
end
```
"""
struct BatchView{TElem,TData,TCollate} <: AbstractDataContainer
    # Type parameters:
    #   TElem    - the batch type reported by `eltype`
    #   TData    - the type of the wrapped data container
    #   TCollate - `Val{nothing}`, `Val{true}` or `Val{false}`; selects how a batch is built
    data::TData      # the wrapped data container
    batchsize::Int   # observations per batch (the last batch may be shorter)
    count::Int       # number of batches exposed by the view
    partial::Bool    # whether a trailing incomplete batch is kept
end
			

			

			
			function BatchView(data::T; batchsize::Int=1, partial::Bool=true,
                   collate=Val(nothing)) where {T}
    # A non-positive batch size would otherwise only surface later as a
    # `DivideError` from `cld`/`fld` below.
    batchsize >= 1 ||
        throw(ArgumentError("`batchsize` must be a positive integer, got $batchsize"))
    n = numobs(data)
    # Shrink an oversized batch to the whole dataset. An empty container is left
    # alone: shrinking the batch size to 0 would make `cld(n, batchsize)` divide by zero.
    if 0 < n < batchsize
        @warn "Number of observations less than batch-size, decreasing the batch-size to $n"
        batchsize = n
    end
    # Lift `collate` to the type domain so `_getbatch` can dispatch on it.
    cval = collate isa Val ? collate : Val(collate)
    if !(cval in (Val(nothing), Val(true), Val(false)))
        throw(ArgumentError("`collate` must be one of `nothing`, `true` or `false`."))
    end
    E = _batchviewelemtype(data, cval)
    # `partial=true` keeps a trailing short batch; `partial=false` drops it.
    count = partial ? cld(n, batchsize) : fld(n, batchsize)
    return BatchView{E,T,typeof(cval)}(data, batchsize, count, partial)
end
			

			

			
			
			
			# Batch type for `collate=nothing`: whatever inference reports `getobs`
# returns for a `UnitRange{Int}` of observation indices.
function _batchviewelemtype(::TData, ::Val{nothing}) where {TData}
    return Core.Compiler.return_type(getobs, Tuple{TData,UnitRange{Int}})
end

# Batch type for `collate=false`: a `Vector` of single observations, whose type is
# what inference reports `getobs` returns for one `Int` index.
function _batchviewelemtype(::TData, ::Val{false}) where {TData}
    obstype = Core.Compiler.return_type(getobs, Tuple{TData,Int})
    return Vector{obstype}
end

# Batch type for `collate=true`: what inference reports `batch` returns when applied
# to the `collate=false` batch type.
function _batchviewelemtype(data, ::Val{true})
    unbatched = _batchviewelemtype(data, Val(false))
    return Core.Compiler.return_type(batch, Tuple{unbatched})
end
			

			

			
			
			
			"""
    batchsize(data::BatchView) -> Int

Return the batch size used by `data`, i.e. the number of observations in
every batch. When the view was created with `partial=true`, the last batch
may contain fewer observations.

# Examples

```julia
using MLUtils
X, Y = MLUtils.load_iris()

A = BatchView(X, batchsize=30)
@assert batchsize(A) == 30
```
"""
function batchsize(A::BatchView)
    return getfield(A, :batchsize)
end
			

			

			
			
			
			# Number of batches in the view; computed once by the constructor.
function Base.length(A::BatchView)
    return getfield(A, :count)
end
			

			

			
			
			# Fetch every observation of the wrapped container (indices `1:numobs(data)`,
# including any observations that `partial=false` drops from iteration) as a
# single batch, built according to the view's `collate` setting.
Base.@propagate_inbounds function getobs(A::BatchView)
    allindices = 1:numobs(A.data)
    return _getbatch(A, allindices)
end
			

			

			
			
			# Return batch number `i`; `_batchrange` performs the bounds check.
Base.@propagate_inbounds Base.getindex(A::BatchView, i::Int) =
    _getbatch(A, _batchrange(A, i))
			

			

			
			
			# Return the batches selected by the batch indices in `is`, merged into a single
# batch. Observation indices follow the order of the batches in `is`; a repeated
# batch index contributes its observations only once (first occurrence wins).
# An empty `is` yields an empty batch.
Base.@propagate_inbounds function Base.getindex(A::BatchView, is::AbstractVector)
    obsindices = Int[]
    sizehint!(obsindices, length(is) * A.batchsize)
    for i in is
        append!(obsindices, _batchrange(A, i))
    end
    # Same result and ordering as `union` over the ranges, without splatting and
    # without failing on an empty `is`.
    unique!(obsindices)
    return _getbatch(A, obsindices)
end
			

			

			
			# `collate=true`: gather the observations one by one, then combine them with `batch`.
function _getbatch(A::BatchView{<:Any,<:Any,Val{true}}, obsindices)
    data = A.data
    observations = [getobs(data, idx) for idx in obsindices]
    return batch(observations)
end
			

			
			# `collate=false`: the batch is a plain vector of individual observations.
function _getbatch(A::BatchView{<:Any,<:Any,Val{false}}, obsindices)
    data = A.data
    return [getobs(data, idx) for idx in obsindices]
end
			

			
			# `collate=nothing` (default): hand all indices to the container's own `getobs`.
function _getbatch(A::BatchView{<:Any,<:Any,Val{nothing}}, obsindices)
    data = A.data
    return getobs(data, obsindices)
end
			

			

			
			
			
			# The data container wrapped by the view.
Base.parent(A::BatchView) = getfield(A, :data)
			

			
			
			
			
			# The batch type is the first type parameter, fixed by the constructor.
# It is defined on the type, as the `Base.eltype` documentation recommends, so that
# `eltype(typeof(A))` reports it too. The instance method forwards to the type method.
Base.eltype(::Type{<:BatchView{TElem}}) where {TElem} = TElem
Base.eltype(A::BatchView) = eltype(typeof(A))

# Override the `AbstractDataContainer` default: each iteration step yields the
# whole batch `A[state]`, for `state` from 1 up to `numobs(A)`.
function Base.iterate(A::BatchView, state=1)
    state > numobs(A) && return nothing
    return (A[state], state + 1)
end

# Helper function to translate a batch index into the range of observation indices
# it covers. Batch `k` spans `(k - 1) * batchsize + 1` to `k * batchsize`, truncated
# at the number of observations in the parent data (so a partial last batch is shorter).
# Throws a `BoundsError` unless `1 <= batchindex <= length(A)` (when bounds checks are on).
@inline function _batchrange(A::BatchView, batchindex::Int)
    # Batch indices are 1-based: index 0 must be rejected as well, otherwise it
    # produces a range of non-positive observation indices.
    @boundscheck (1 <= batchindex <= A.count) || throw(BoundsError(A, batchindex))
    startidx = (batchindex - 1) * A.batchsize + 1
    endidx = min(numobs(parent(A)), startidx + A.batchsize - 1)
    return startidx:endidx
end
			

			

			
			# Compact one-line description, e.g. `BatchView(<parent>, batchsize=30, partial=true)`,
# followed by the batch type when printed at top level.
function Base.showarg(io::IO, A::BatchView, toplevel)
    print(io, "BatchView(")
    Base.showarg(io, parent(A), false)
    print(io, ", ")
    print(io, "batchsize=$(A.batchsize), ")
    print(io, "partial=$(A.partial)")
    print(io, ')')
    if toplevel
        T = eltype(A)
        # `nameof` gives a short name for concrete and parametric types but has no
        # method for `Union`s (including `Union{}`), which inference may produce.
        print(io, " with eltype ", T isa Union{DataType,UnionAll} ? nameof(T) : T) # simplify
    end
    return nothing
end