sanitycheck.jl

`YVIq2/src/callbacks/sanitycheck.jl` is a source file in module `YVIq2`.

			
			
			
			
			
			"""
    Check(checkfn, name, throw_error, message)

A single sanity check.

- `checkfn`: called as `checkfn(learner)`; the check fails when the result is `false`.
- `name`: short label used when the check is displayed or reported.
- `throw_error`: when `true`, a failure of this check makes `runchecks` throw a
  `SanityCheckException`; otherwise the failure is reported as a `WARNING` only.
- `message`: explanation printed when the check fails.
"""
struct Check
    checkfn::Any
    name::Any
    throw_error::Bool
    message::Any
end
			
			

			
			

			
			

			
			
			
			
			"""
    SanityCheckException()

Thrown by `runchecks` when at least one failed check has `throw_error == true`.
"""
struct SanityCheckException <: Exception end

# Without a custom `showerror`, the error is displayed as the bare
# `SanityCheckException()`, which gives the user no hint about what went wrong.
Base.showerror(io::IO, ::SanityCheckException) = print(
    io,
    "SanityCheckException: at least one sanity check marked as an error failed; ",
    "see the failed checks printed before this error for details.",
)
			
			

			
			

			
			
			
			
			# Compact one-line display of a check, e.g. `Check("Has training data iterator")`.
function Base.show(io::IO, check::Check)
    label = check.name
    print(io, "Check(\"$label\")")
end
			
			

			
			

			
			

			
			
			"""
    runchecks(checks, learner)

Evaluate every check's `checkfn` on `learner`, in order.

If some checks fail, print a numbered report to `stdout` listing each failed
check's name, its severity (`ERROR` when `throw_error` is set, else `WARNING`)
and its message. Then throw a `SanityCheckException` if any failed check has
`throw_error == true`. Returns `nothing` otherwise.
"""
function runchecks(checks, learner)
    # A typed `Bool[...]` comprehension converts each result with `convert(Bool, …)`,
    # exactly as storing into a `BitVector` would.
    failed = Bool[!c.checkfn(learner) for c in checks]
    failedchecks = checks[failed]
    isempty(failedchecks) && return nothing

    println("$(length(failedchecks))/$(length(checks)) sanity checks failed:")
    for (idx, c) in pairs(failedchecks)
        println("---")
        severity = c.throw_error ? "ERROR" : "WARNING"
        println(idx, ": ", c.name, " (", severity, ")")
        println()
        println(c.message)
    end

    # Warnings are only reported; error-level failures abort via an exception.
    if any(c -> c.throw_error, failedchecks)
        throw(SanityCheckException())
    end
    return nothing
end
			
			

			
			

			
			

			
			
			"""
    SanityCheck(checks = []; usedefault = isempty(checks))

Callback that runs a list of `Check`s against the learner on an `EpochBegin`
event, until they have run once. When `usedefault` is `true` (the default when
no checks are given), the built-in `CHECKS` are appended after `checks`.
"""
mutable struct SanityCheck <: Callback
    # checks to evaluate, in order
    checks::Vector{Check}
    # set to `true` once the checks have been run
    checked::Bool

    function SanityCheck(checks = []; usedefault = isempty(checks))
        selected = usedefault ? vcat(checks, CHECKS) : checks
        return new(selected, false)
    end
end
			
			

			
			

			
			

			
			
			
			
			# Compact display showing only how many checks the callback holds.
function Base.show(io::IO, cb::SanityCheck)
    n = length(cb.checks)
    print(io, "SanityCheck($n checks)")
end
			
			

			
			

			
			

			
			

			
			
			
	
		
			# Access declaration for `SanityCheck`: maps learner fields to `Read()`
# (presumably read-only permission — confirm against the `Read` definition).
function stateaccess(::SanityCheck)
    return (
        data = Read(),
        model = Read(),
        lossfn = Read(),
        optimizer = Read(),
        callbacks = Read(),
    )
end
			
			

			
			

			
			

			
			
			# Run the sanity checks on the first `EpochBegin` of any training phase.
# `cb.checked` is set only after `runchecks` returns, so if it throws, the
# checks are attempted again on the next `EpochBegin`.
function on(::EpochBegin, phase::AbstractTrainingPhase, cb::SanityCheck, learner)
    cb.checked && return nothing
    runchecks(cb.checks, learner)
    cb.checked = true
end

Checks


			
			
			
			
	
		
			# Error-level check: the learner must have a training data iterator.
function CheckDataIteratorTrain()
    hastrainiter(learner) = !isnothing(learner.data.training)
    return Check(
        hastrainiter,
        "Has training data iterator",
        true,
        """
        `learner` training data iterator is `nothing` (not set). You can pass it to
        `Learner` in one of the following ways:

        - `Learner(model, traindataiter, opt, lossfn)`
        - `Learner(model, (traindataiter, valdataiter), opt, lossfn)`
        - `Learner(model, (training = traindataiter,), opt, lossfn)`
        """,
    )
end
			
			

			
			

			
			
			
	
		
			# Warning-level check: a validation data iterator is optional, so its
# absence is only reported.
function CheckDataIteratorValid()
    hasvaliditer(learner) = !isnothing(learner.data.validation)
    return Check(
        hasvaliditer,
        "Has validation data iterator",
        false,
        """
        `learner` validation data iterator is `nothing` (not set).
        It is not mandatory, but you must set it to fit `ValidationPhase`s.

        You can pass it to `Learner` like this:

        - `Learner(model, (traindataiter, valdataiter), opt, lossfn)`

        Or if you want to use training data as validation data:

        - `Learner(model, (traindataiter, traindataiter), opt, lossfn)`
        """,
    )
end
			
			

			
			

			
			
			
	
		
			"""
    CheckIteratesTuples()

Error-level check that the first item of the training data iterator can be
destructured into `(x, y)`. Only `learner.data.training` is inspected.
"""
CheckIteratesTuples() =
    Check(
        "Data iterators iterate over tuples",
        true,
        """
        Data iterators need to be iterable and return tuples.
        This means that `for (x, y) in dataiter end` works where
        `(x, y)` is a pair of model inputs and outputs.
        """
    ) do learner
        try
            batch, _ = Base.iterate(learner.data.training)
            # The destructuring itself is the check; `x` and `y` are not used.
            x, y = batch
        catch e
            # Any error means the iterator is unusable, but a user interrupt must
            # still abort instead of being reported as a failed check.
            e isa InterruptException && rethrow()
            return false
        end
        return true
    end
			
			

			
			

			
			
			
	
		
			"""
    CheckModelLossStep()

Error-level check that one forward pass plus loss computation succeeds on the
first training batch and that the loss is a `Number`.
"""
CheckModelLossStep() =
    Check(
        "Model and loss function compatible with data",
        true,
        """
        To perform the optimization step, the model and loss function need
        to be compatible with the data. This means the following must work:

        - `(x, y), _ = iterate(learner.data.training)`
        - `ŷ = learner.model(x)`
        - `loss = learner.lossfn(ŷ, y)`
        """
    ) do learner
        try
            # Move model and batch with `gpu` only when a `ToGPU()` callback is
            # registered (presumably so the check matches the device used in training).
            dev = ToGPU() in learner.callbacks.cbs ? gpu : identity
            x, y = dev(iterate(learner.data.training)[1])
            ŷ = dev(learner.model)(x)
            # Explicit test rather than `@assert`, which may be disabled.
            return learner.lossfn(ŷ, y) isa Number
        catch e
            # Any error means incompatibility, but never swallow a user interrupt.
            e isa InterruptException && rethrow()
            return false
        end
    end
			
			

			
			

			
			

			
			
			# Built-in checks appended by `SanityCheck` when `usedefault` is `true`,
# listed in the order they are run.
const CHECKS = Check[
    makecheck() for makecheck in (
        CheckDataIteratorTrain,
        CheckDataIteratorValid,
        CheckIteratesTuples,
        CheckModelLossStep,
    )
]