FluxTraining
			
			
			
			
			"""
			    Check(checkfn, name, throw_error, message)

			A single named sanity check.

			# Fields

			- `checkfn`: callable invoked as `checkfn(learner)`. `runchecks` negates its
			  result, so it has to return a `Bool`; `true` means the check passed.
			- `name`: label printed by `show` and in the failure report.
			- `throw_error`: when `true`, a failure is reported as `ERROR` and makes
			  `runchecks` throw a `SanityCheckException`; otherwise it is reported as
			  `WARNING`.
			- `message`: explanation printed when the check fails.
			"""
			struct Check
			    checkfn::Any
			    name::Any
			    throw_error::Bool
			    message::Any
			end
			
			
			
			
			
			"""
			    SanityCheckException()

			Field-less exception type. `runchecks` throws it when at least one failed
			check has `throw_error` set.
			"""
			struct SanityCheckException <: Exception
			end
			
			
			
			
			
			# Compact display of a `Check`: writes `Check("<name>")` to `io`.
			function Base.show(io::IO, check::Check)
			    print(io, "Check(\"", check.name, "\")")
			end
			
			
			
			
			"""
			    runchecks(checks, learner)

			Call `check.checkfn(learner)` for every element of `checks` and report the
			ones that did not return `true`.

			Returns `nothing` without printing when every check passes. Otherwise a
			summary line and, per failed check, its index, name, severity (`ERROR` or
			`WARNING`) and message are printed to standard output. If any failed check
			has `throw_error` set, a `SanityCheckException` is thrown after the report.
			"""
			function runchecks(checks, learner)
			    # Flag each check whose predicate does not hold for `learner`.
			    failed = falses(length(checks))
			    for (idx, chk) in enumerate(checks)
			        failed[idx] = !chk.checkfn(learner)
			    end
			    failedchecks = checks[failed]

			    # Nothing to report when every check passed.
			    isempty(failedchecks) && return

			    println("$(length(failedchecks))/$(length(checks)) sanity checks failed:")
			    for (idx, chk) in enumerate(failedchecks)
			        println("---")
			        severity = chk.throw_error ? "ERROR" : "WARNING"
			        println(idx, ": ", chk.name, " (", severity, ")")
			        println()
			        println(chk.message)
			    end

			    # Warning-level failures are only printed; an error-level failure aborts.
			    if any(getfield.(failedchecks, :throw_error))
			        throw(SanityCheckException())
			    end
			    return nothing
			end
			
			
			
			
			
			
			"""
			
			    SanityCheck([checks; usedefault = true])
			
			Callback that runs sanity [`Check`](#)s when the `Learner` is initialized.
			If `usedefault` is `true`, it will run all checks in FluxTraining.CHECKS
			in addition to the ones you pass in.
			"""
			
			
			mutable
			
			 
			struct
			
			 
	
			SanityCheck
			 
			<:
			 
	
			Callback
			
			
    
			
			checks
			::
			
			Vector
			{
	
			Check
			}
			
    
			
			checked
			::
			Bool
			
    
			
			function
			 
			
	
			SanityCheck
			(
			
			checks
			 
			=
			 
			
			[
			]
			
			;
			 
			
			usedefault
			 
			=
			 
			
			isempty
			(
			checks
			)
			)
			
			
        
			
			if
			 
			usedefault
			
			
            
			
			checks
			 
			=
			 
			
			vcat
			(
			checks
			,
			 
	
			CHECKS
			)
			
        
			end
			
        
			
			return
			 
			
			new
			(
			checks
			,
			 
			false
			)
			
    
			end
			
			end
			
			
			
			
			
			
			# Compact display of a `SanityCheck`: writes `SanityCheck(<n> checks)` to `io`,
			# where `<n>` is the number of stored checks.
			function Base.show(io::IO, cb::SanityCheck)
			    nchecks = length(cb.checks)
			    print(io, "SanityCheck(", nchecks, " checks)")
			end
			
			
			
			
			
			
	
			# Access declaration for `SanityCheck`: a named tuple mapping `data`, `model`,
			# `lossfn`, `optimizer` and `callbacks` to `Read()`.
			function stateaccess(::SanityCheck)
			    return (
			        data = Read(),
			        model = Read(),
			        lossfn = Read(),
			        optimizer = Read(),
			        callbacks = Read(),
			    )
			end
			
			
			
			
			# Runs the stored checks at `EpochBegin` of a training phase, but only until
			# they have completed once: `cb.checked` is set after `runchecks` returns, so
			# a run that throws leaves it `false`.
			function on(::EpochBegin, phase::AbstractTrainingPhase, cb::SanityCheck, learner)
			    cb.checked && return
			    runchecks(cb.checks, learner)
			    cb.checked = true
			end

			# Checks
			
			
			
			
	
			# Error-level check that passes when `learner.data.training` is not `nothing`.
			function CheckDataIteratorTrain()
			    message = """
			    `learner` training data iterator is `nothing` (not set). You can pass it to
			    `Learner` in one of the following ways:

			    - `Learner(model, traindataiter, opt, lossfn)`
			    - `Learner(model, (traindataiter, valdataiter), opt, lossfn)`
			    - `Learner(model, (training = traindataiter,), opt, lossfn)`
			    """
			    hastrainingdata(learner) = !isnothing(learner.data.training)
			    return Check(hastrainingdata, "Has training data iterator", true, message)
			end
			
			
			
			
	
			# Warning-level check that passes when `learner.data.validation` is not `nothing`.
			function CheckDataIteratorValid()
			    message = """
			    `learner` validation data iterator is `nothing` (not set).
			    It is not mandatory, but you must set it to fit `ValidationPhase`s.

			    You can pass it to `Learner` like this:

			    - `Learner(model, (traindataiter, valdataiter), opt, lossfn)`

			    Or if you want to use training data as validation data:

			    - `Learner(model, (traindataiter, traindataiter), opt, lossfn)`
			    """
			    hasvalidationdata(learner) = !isnothing(learner.data.validation)
			    return Check(hasvalidationdata, "Has validation data iterator", false, message)
			end
			
			
			
			
	
			# Error-level check that passes when the first item produced by iterating
			# `learner.data.training` can be destructured into two values. Any exception
			# raised while trying counts as a failed check.
			function CheckIteratesTuples()
			    message = """
			    Data iterators need to be iterable and return tuples.
			    This means that `for (x, y) in dataiter end` works where
			    `(x, y)` is a pair of model inputs and outputs.
			    """

			    function iteratestuples(learner)
			        try
			            # `iterate` returns `nothing` for an empty iterator, which fails here.
			            batch, _ = Base.iterate(learner.data.training)
			            x, y = batch
			        catch
			            return false
			        end
			        return true
			    end

			    return Check(iteratestuples, "Data iterators iterate over tuples", true, message)
			end
			
			
			
			
	
			# Error-level check that passes when one batch from `learner.data.training`
			# can be pushed through `learner.model` and `learner.lossfn`, and the loss is
			# a `Number`. Any exception raised along the way counts as a failed check.
			function CheckModelLossStep()
			    message = """
			    To perform the optimization step, the model and loss function need
			    to be compatible with the data. This means the following must work:

			    - `(x, y), _ = iterate(learner.data.training)`
			    - `ŷ = learner.model(x)`
			    - `loss = learner.lossfn(ŷ, y)`
			    """

			    return Check("Model and loss function compatible with data", true, message) do learner
			        try
			            # Move batch and model with `gpu` when a `ToGPU` callback is registered.
			            dev = ToGPU() in learner.callbacks.cbs ? gpu : identity
			            x, y = dev(iterate(learner.data.training)[1])
			            ŷ = dev(learner.model)(x)
			            # Test the loss type explicitly instead of via `@assert`: asserts may
			            # be disabled at some optimization levels, which would let a
			            # non-numeric loss pass this check.
			            return learner.lossfn(ŷ, y) isa Number
			        catch
			            return false
			        end
			    end
			end
			
			
			
			
			# Default checks, in the order they are run; `SanityCheck` appends these when
			# `usedefault` is true.
			const CHECKS = Check[
			    CheckDataIteratorTrain(),
			    CheckDataIteratorValid(),
			    CheckIteratesTuples(),
			    CheckModelLossStep(),
			]