testutils.jl

`FluxTraining/testutils.jl` is a source file in module `FluxTraining`.

This file defines `testlearner`, which constructs a `Learner` with a very simple optimisation problem. It should be used in tests for callbacks.


			
			
			
			
			"""
    TestModel(coeff)

Minimal trainable model for tests: calling it on `x` returns `x .* coeff`.
A scalar `coeff` is stored as a one-element vector (see the `Number`
constructor below); any other value is stored as given.
"""
struct TestModel
    coeff
end

# Scalar convenience constructor: wrap the number in a 1-element vector.
# Presumably this keeps the parameter mutable so optimisers can update it
# in place (plain numbers are immutable) — NOTE(review): confirm intent.
function TestModel(coeff::Number)
    return TestModel([coeff])
end

# Expose only the coefficient as a trainable parameter.
function Flux.trainable(model::TestModel)
    return (model.coeff,)
end

# Forward pass: scale the input elementwise by the coefficient.
function (model::TestModel)(x)
    return x .* model.coeff
end

Flux.@functor TestModel
			

			

			

			
			"""
    testbatch(batchsize, coeff)

Return a single `(inputs, targets)` pair. `inputs` is a vector of
`batchsize` values drawn uniformly from the range `1.0:100.0`, so they are
whole numbers stored as `Float64`. `targets` is `inputs` multiplied
elementwise by `coeff`.
"""
function testbatch(batchsize, coeff)
    inputs = rand(1.0:100.0, batchsize)
    targets = inputs .* coeff
    return inputs, targets
end
			

			

			

			
			"""
    testbatches(n, coeff, batchsize = 8)

Return a lazy generator of `n` batches. Each batch is produced by
`testbatch(batchsize, coeff)`.

The generator is lazy, so every iteration calls `testbatch` again and draws
fresh random samples. `collect` the result if the exact same batches must be
iterated more than once.

`n` may be any `Integer`. A value of `n <= 0` yields an empty generator.
"""
function testbatches(n::Integer, coeff, batchsize = 8)
    # `n::Integer` (rather than `Int`) also accepts Int32, UInt8, etc.;
    # `1:n` works for any integer type.
    return (testbatch(batchsize, coeff) for _ in 1:n)
end
			

			

			

			
			
			
			"""
    testlearner(callbacks...[; opt, nbatches, coeff, batchsize, usedefaultcallbacks, kwargs...])

Construct a [`Learner`](#) with a simple optimization problem. This
learner should be used in tests that require training a model, e.g.
for callbacks.

The model is a `TestModel` with a random initial coefficient. The data are
`nbatches` batches whose targets are the inputs scaled by `coeff`, and
`Flux.mae` is the loss function. The same collected batches are passed to
`Learner` twice, as `(data, data)`.

Positional arguments (`callbacks...`) are forwarded to `Learner` after the
loss function.

## Keyword arguments

- `opt = Descent(0.001)`: optimizer passed to `Learner`.
- `nbatches = 16`: number of batches in the dataset.
- `coeff = 3`: coefficient used to generate the targets.
- `batchsize = 8`: number of samples per batch.
- `usedefaultcallbacks = false`: forwarded to `Learner`.
- `kwargs...`: any remaining keyword arguments are forwarded to `Learner`.
"""
function testlearner(
        args...;
        opt = Descent(0.001),
        nbatches = 16,
        coeff = 3,
        batchsize = 8,
        usedefaultcallbacks = false,
        kwargs...)
    # Random starting coefficient, so the optimizer has something to fit.
    model = TestModel(rand())
    # `testbatches` returns a lazy generator that redraws random samples on
    # every iteration; collect it so both entries of `(data, data)` hold the
    # identical, fixed set of batches.
    data = collect(testbatches(nbatches, coeff, batchsize))
    Learner(
        model,
        (data, data),
        opt,
        Flux.mae,
        args...;
        usedefaultcallbacks = usedefaultcallbacks,
        kwargs...)
end