sanitycheck.jl

FluxTraining/callbacks/sanitycheck.jl is a source file in module FluxTraining

			
			
			
			
			# A single sanity check run against a `Learner`.
# Fields:
# - `checkfn`: called as `checkfn(learner)`; returning `true` means the check passed
# - `name`: short label shown in the failure report and by `show`
# - `throw_error`: when `true`, a failure is reported as ERROR and aborts with
#   `SanityCheckException`; when `false`, it is only reported as WARNING
# - `message`: explanation printed when the check fails
struct Check
    checkfn
    name
    throw_error::Bool
    message
end
			

			

			

			
			
			# Thrown by `runchecks` after printing its failure report, when at least
# one failed check has `throw_error = true`. Carries no payload.
struct SanityCheckException <: Exception
end
			

			

			
			
			
			# Display a check by its name only, e.g. `Check("Has training data iterator")`.
function Base.show(io::IO, check::Check)
    label = check.name
    print(io, "Check(\"", label, "\")")
end
			

			

			

			
			# Run every check against `learner`. If any fail, print a numbered report of
# the failures to stdout. Then throw `SanityCheckException` if any failed check
# has `throw_error` set. Returns `nothing` otherwise.
function runchecks(checks, learner)
    # A check fails when its predicate returns `false`.
    failedchecks = filter(check -> !check.checkfn(learner), checks)
    isempty(failedchecks) && return

    println(length(failedchecks), "/", length(checks), " sanity checks failed:")
    for (n, check) in enumerate(failedchecks)
        println("---")
        severity = check.throw_error ? "ERROR" : "WARNING"
        println(n, ": ", check.name, " (", severity, ")")
        println()
        println(check.message)
    end

    # Warnings are only reported; a single ERROR-level failure aborts.
    if any(check -> check.throw_error, failedchecks)
        throw(SanityCheckException())
    end
    return nothing
end
			

			

			

			
			
			
			"""
    SanityCheck([checks; usedefault = isempty(checks)])

Callback that runs sanity [`Check`](#)s the first time an epoch begins in an
`AbstractTrainingPhase`, i.e. before any training step is taken.

If `usedefault` is `true`, it will run all checks in `FluxTraining.CHECKS`
in addition to the ones you pass in. `usedefault` defaults to `true` only when
no `checks` are given, so passing your own checks replaces the default ones
unless you also set `usedefault = true`.

Failed checks are printed; if any failed check has `throw_error = true`, a
`SanityCheckException` is thrown.
"""
mutable struct SanityCheck <: Callback
    # Checks to run: the user-supplied ones, followed by `CHECKS` if `usedefault`.
    checks::Vector{Check}
    # Set to `true` once `runchecks` has returned without throwing.
    checked::Bool
    function SanityCheck(checks = []; usedefault = isempty(checks))
        if usedefault
            checks = vcat(checks, CHECKS)
        end
        return new(checks, false)
    end
end
			

			

			

			
			
			
			# Compact display showing only how many checks are configured,
# e.g. `SanityCheck(4 checks)`.
function Base.show(io::IO, cb::SanityCheck)
    nchecks = length(cb.checks)
    print(io, "SanityCheck(", nchecks, " checks)")
end
			

			

			

			

			
			

	
			# Learner state this callback may access, each field declared as `Read()`.
function stateaccess(::SanityCheck)
    return (
        data = Read(),
        model = Read(),
        lossfn = Read(),
        optimizer = Read(),
        callbacks = Read(),
    )
end
			

			

			

			
			# At the start of an epoch in a training phase, run the sanity checks once.
# `checked` is only set after `runchecks` returns, so if it throws, the checks
# run again the next time an epoch begins.
function on(::EpochBegin, phase::AbstractTrainingPhase, cb::SanityCheck, learner)
    cb.checked && return
    runchecks(cb.checks, learner)
    cb.checked = true
end

Checks


			
			
			
			

	
			# Fatal check: `learner.data.training` must not be `nothing`.
function CheckDataIteratorTrain()
    message = """
    `learner` training data iterator is `nothing` (not set). You can pass it to
    `Learner` in one of the following ways:

    - `Learner(model, traindataiter, opt, lossfn)`
    - `Learner(model, (traindataiter, valdataiter), opt, lossfn)`
    - `Learner(model, (training = traindataiter,), opt, lossfn)`
    """
    hastrainingdata(learner) = !isnothing(learner.data.training)
    return Check(hastrainingdata, "Has training data iterator", true, message)
end
			

			

			
			

	
			# Non-fatal check (reported as a warning): `learner.data.validation`
# should not be `nothing`.
function CheckDataIteratorValid()
    message = """
    `learner` validation data iterator is `nothing` (not set).
    It is not mandatory, but you must set it to fit `ValidationPhase`s.

    You can pass it to `Learner` like this:

    - `Learner(model, (traindataiter, valdataiter), opt, lossfn)`

    Or if you want to use training data as validation data:

    - `Learner(model, (traindataiter, traindataiter), opt, lossfn)`
    """
    hasvalidationdata(learner) = !isnothing(learner.data.validation)
    return Check(hasvalidationdata, "Has validation data iterator", false, message)
end
			

			

			
			

	
			# Fatal check: the first element produced by the training iterator must
# destructure into `(x, y)`. Only `learner.data.training` is inspected.
function CheckIteratesTuples()
    message = """
    Data iterators need to be iterable and return tuples.
    This means that `for (x, y) in dataiter end` works where
    `(x, y)` is a pair of model inputs and outputs.
    """
    function iteratestuples(learner)
        try
            # `iterate` returns `(element, state)`; the element itself must
            # destructure into two values.
            (x, y), _ = Base.iterate(learner.data.training)
        catch
            return false
        end
        return true
    end
    return Check(iteratestuples, "Data iterators iterate over tuples", true, message)
end
			

			

			
			

	
			# Fatal check: one batch from the training iterator must run through the
# model and the loss function, and the loss must be a `Number`.
function CheckModelLossStep()
    message = """
    To perform the optimization step, the model and loss function need
    to be compatible with the data. This means the following must work:

    - `(x, y), _ = iterate(learner.data.training)`
    - `ŷ = learner.model(x)`
    - `loss = learner.lossfn(ŷ, y)`
    """
    function modellossstep(learner)
        try
            # Move the batch and the model with `gpu` only when a `ToGPU()`
            # callback is registered; otherwise use them as-is.
            dev = ToGPU() in learner.callbacks.cbs ? gpu : identity
            x, y = dev(iterate(learner.data.training)[1])
            ŷ = dev(learner.model)(x)
            # Compute the loss outside of `@assert`: assertions may be disabled
            # at some optimization levels, which would skip the loss call and
            # make this check pass without testing anything.
            loss = learner.lossfn(ŷ, y)
            return loss isa Number
        catch
            # Any error (e.g. an empty iterator or a model/loss mismatch)
            # means the step cannot be performed.
            return false
        end
    end
    return Check(modellossstep, "Model and loss function compatible with data", true, message)
end
			

			

			

			
			# Default checks appended by `SanityCheck` when `usedefault` is `true`.
const CHECKS = Check[
    makecheck() for makecheck in (
        CheckDataIteratorTrain,
        CheckDataIteratorValid,
        CheckIteratesTuples,
        CheckModelLossStep,
    )
]