splitobs.jl

MLUtils/splitobs.jl is a source file in module MLUtils

			
			
			
			
			
			"""
			

			    splitobs(n::Int; at) -> Tuple

			

			Compute the indices for two or more disjoint subsets of

			the range `1:n` with splits given by `at`.

			

			# Examples

			

			```julia

			julia> splitobs(100, at=0.7)

			(1:70, 71:100)

			

			julia> splitobs(100, at=(0.1, 0.4))

			(1:10, 11:50, 51:100)

			```

			"""
			

			
			

	
			splitobs
			(
			
			n
			::
			Int
			
			;
			 
			at
			)
			 
			=
			 
			
			_splitobs
			(
			n
			,
			 
			at
			)
			

			

			
			
			# Two-way split of `1:n` given an absolute count: the first range holds
# exactly `at` indices and the second holds the remaining `n - at`.
# Only integer arithmetic is used, so nothing is rounded. `n == 0` also
# works, whereas converting to a fraction via `at / n` would produce `NaN`
# or `Inf` there.
#
# Throws `ArgumentError` unless `0 <= at <= n`.
function _splitobs(n::Int, at::Integer)
    0 <= at <= n ||
        throw(ArgumentError("the parameter \"at\" must be a count in [0, $n]"))
    k = Int(at)  # normalise e.g. Int32/Bool counts so both ranges are UnitRange{Int}
    return (1:k, k+1:n)
end
			 

			
			
			
			# Multi-way split of `1:n` given absolute counts: range `k` holds exactly
# `at[k]` indices and a final range holds whatever is left, so `N` counts
# yield `N + 1` consecutive ranges. Only integer arithmetic is used, so
# nothing is rounded. `n == 0` also works, whereas converting the counts to
# fractions via `at ./ n` would divide by zero there.
#
# Throws `ArgumentError` if any count is negative or larger than `n`, or if
# the counts add up to more than `n`.
function _splitobs(n::Int, at::NTuple{N,<:Integer}) where {N}
    for c in at
        0 <= c <= n ||
            throw(ArgumentError("the parameter \"at\" must contain counts in [0, $n]"))
    end
    counts = map(Int, at)
    reduce(+, counts; init = 0) <= n ||
        throw(ArgumentError("the counts in the parameter \"at\" must sum to at most $n"))
    # End index of each of the first N ranges (running total of the counts).
    stops = ntuple(k -> sum(j -> counts[j], 1:k), Val(N))
    bounds = (0, stops..., n)
    return ntuple(k -> (bounds[k] + 1):bounds[k + 1], Val(N + 1))
end
			 

			

			
			
			# With no split points there is nothing to divide: the whole index range
# `1:n` comes back as the sole element of a 1-tuple.
function _splitobs(n::Int, ::Tuple{})
    whole = UnitRange(1, n)
    return (whole,)
end
			

			

			
			# Two-way split of `1:n` by a fraction: the first range receives
# `round(Int, at * n)` indices (round-to-nearest, ties to even, clamped to
# `0:n`), and the second range receives the rest.
#
# Throws `ArgumentError` unless `0 <= at <= 1`. Both endpoints are accepted,
# so the message names the closed interval. `NaN` is rejected too, because
# every comparison with `NaN` is false.
function _splitobs(n::Int, at::AbstractFloat)
    0 <= at <= 1 || throw(ArgumentError("the parameter \"at\" must be in interval [0, 1]"))
    n1 = clamp(round(Int, at * n), 0, n)
    return (1:n1, n1+1:n)
end
			

			

			
			# Multi-way split of `1:n` by fractions. Entry `k` of `at` is the fraction
# of *all* `n` indices assigned to range `k`; a final range receives whatever
# is left, so `N` fractions yield `N + 1` consecutive ranges.
#
# Each range is sized against the full `n` instead of rescaling the remaining
# fractions after every rounded split. Rescaling can push a later fraction
# above 1: with `n = 3, at = (0.2, 0.8)`, the first range rounds up to 1
# index, and the second would then ask for 1.2x of the 2 indices left.
# Rescaling also divides by zero once no indices remain (`at[1] == 1.0`, or
# `n == 0`).
#
# Throws `ArgumentError` if any fraction lies outside `[0, 1]` (or is `NaN`),
# or if the fractions sum to more than 1 beyond floating-point noise.
function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
    isempty(at) && return (1:n,)
    for x in at
        0 <= x <= 1 || throw(ArgumentError("the parameter \"at\" must be in interval [0, 1]"))
    end
    total = sum(at)
    # tolerate rounding noise in sums such as 0.1 + 0.2 + 0.7
    (total <= 1 || isapprox(total, 1)) ||
        throw(ArgumentError("the fractions in the parameter \"at\" must sum to at most 1"))
    sizes = map(x -> round(Int, x * n), at)
    # Cumulative end index of each of the first N ranges. It is capped at `n`
    # because rounding several sizes up can overshoot the total slightly.
    stops = ntuple(k -> min(sum(j -> sizes[j], 1:k), n), Val(N))
    bounds = (0, stops..., n)
    return ntuple(k -> (bounds[k] + 1):bounds[k + 1], Val(N + 1))
end
			

			

			
			
			
			"""
			

			    splitobs(data; at, shuffle=false) -> Tuple

			

			Split the `data` into multiple subsets proportional to the

			value(s) of `at`. 

			

			If `shuffle=true`, randomly permute the observations before splitting.

			

			Supports any datatype implementing the [`numobs`](@ref) and

			[`getobs`](@ref) interfaces.

			

			# Examples

			

			```julia

			# A 70%-30% split

			train, test = splitobs(X, at=0.7)

			

			# A 50%-30%-20% split

			train, val, test = splitobs(X, at=(0.5, 0.3))

			

			# A 70%-30% split with multiple arrays and shuffling

			train, test = splitobs((X, y), at=0.7, shuffle=true)

			Xtrain, Ytrain = train

			```

			"""
			

			
			function
			 
			

	
			splitobs
			(
			data
			
			;
			 
			at
			,
			 
			
			
			shuffle
			::
			Bool
			=
			false
			)
			
			
    
			
			if
			 
			shuffle
			
			
        
			
			data
			 
			=
			 
			

	
			shuffleobs
			(
			data
			)
			
    
			end
			
    
			
			n
			 
			=
			 
			

	
			numobs
			(
			data
			)
			
    
			
			return
			 
			
			map
			(
			
			idx
			 
			->
			 
			

	
			obsview
			(
			data
			,
			 
			idx
			)
			,
			 
			

	
			splitobs
			(
			n
			
			;
			 
			at
			)
			)
			

			end