graph.jl

FluxTraining/callbacks/graph.jl is a source file in module FluxTraining

			

TOOD: check for reads on :cbstate without a previous write (missing callback) TODO: check for cyclical dependencies


			
			
			
			
			
			"""
			

			    callbackgraph(callbacks) -> SimpleDiGraph

			

			Creates a directed acyclic graph from a list of `callbacks`.

			Ordering is given through `runafter` and `resolveconflict`.

			

			If a write conflict cannot be resolved (i.e. `resolveconflict`)

			is not implemented), throws an error.

			"""
			

			
			function
			 
			

	
			callbackgraph
			(
			callbacks
			)
			
			
    
			# create a graph
			
    
			
			g
			 
			=
			 
			
			SimpleDiGraph
			(
			
			length
			(
			callbacks
			)
			)
			

			
    
			# add dependencies defined through `runafter`
			
    
			
			foreach
			(
			
			e
			 
			->
			 
			
			add_edge!
			(
			g
			,
			 
			e
			)
			,
			 
			

	
			edgesrunafter
			(
			callbacks
			)
			)
			

			

			
    
			
			permissions
			 
			=
			 
			

	
			stateaccess
			.
			
			(
			callbacks
			)
			
    
			
			writeaccesses
			 
			=
			 
			
			[
			
			

	
			accesses
			(
			ps
			,
			 

	
			Write
			)
			 
			for
			
			 
			ps
			 
			in
			 
			permissions
			]
			
    
			
			readaccesses
			 
			=
			 
			
			[
			
			

	
			accesses
			(
			ps
			,
			 

	
			Read
			)
			 
			for
			
			 
			ps
			 
			in
			 
			permissions
			]
			

			
    
			# detect write-write conflicts and handle them
			
    
			
			for
			
			 
			
			(
			i
			,
			 
			j
			,
			 
			accessi
			,
			 
			accessj
			)
			 
			in
			 
			

	
			findconflicts
			(
			writeaccesses
			,
			 
			writeaccesses
			)
			
			
        
			
			if
			 
			
			!
			(
			
			
			has_edge
			(
			g
			,
			 
			i
			,
			 
			j
			)
			 
			||
			 
			
			has_edge
			(
			g
			,
			 
			j
			,
			 
			i
			)
			)
			
			
            
			
			
			cb1
			,
			 
			cb2
			 
			=
			
			 
			
			callbacks
			[
			i
			]
			,
			 
			
			callbacks
			[
			j
			]
			
            
			
			resolution
			 
			=
			 
			

	
			resolveconflict
			(
			cb1
			,
			 
			cb2
			)
			
            
			
			if
			
			 
			resolution
			 
			isa
			 

	
			NotDefined
			
			
                
			

	
			errorwriteconflict
			(
			cb1
			,
			 
			cb2
			,
			 
			accessi
			,
			 
			accessj
			,
			 
			
			resolvable
			 
			=
			 
			true
			)
			
            
			end
			
            
			
			if
			
			 
			resolution
			 
			isa
			 

	
			Unresolvable
			
			
                
			

	
			errorwriteconflict
			(
			cb1
			,
			 
			cb2
			,
			 
			accessi
			,
			 
			accessj
			,
			 
			
			resolvable
			 
			=
			 
			false
			)
			
            
			
			elseif
			
			 
			resolution
			 
			isa
			 

	
			RunFirst
			
			
                
			
			if
			
			 
			
			resolution
			.
			
			cb
			 
			===
			 
			cb1
			
			
                    
			
			add_edge!
			(
			g
			,
			 
			cb1
			,
			 
			cb2
			)
			
                
			else
			
			
                    
			
			add_edge!
			(
			g
			,
			 
			cb2
			,
			 
			cb1
			)
			
                
			end
			
            
			# resolution isa NoConflict
			
            
			else
			
			
                
			continue
			
            
			end
			
        
			end
			
    
			end
			

			
    
			# detect read-write conflicts and handle them
			
    
			# writes will be places before reads!
			
    
			
			for
			
			 
			
			(
			i
			,
			 
			j
			)
			 
			in
			 
			

	
			findconflicts
			(
			writeaccesses
			,
			 
			readaccesses
			)
			
			
        
			
			resolution
			 
			=
			 
			

	
			resolveconflict
			(
			
			callbacks
			[
			i
			]
			,
			 
			
			callbacks
			[
			j
			]
			)
			
        
			
			if
			
			
			 
			resolution
			 
			isa
			 

	
			RunFirst
			 
			&&
			
			 
			
			resolution
			.
			
			cb
			 
			==
			 
			
			callbacks
			[
			j
			]
			
			
            
			
			add_edge!
			(
			g
			,
			 
			j
			,
			 
			i
			)
			
        
			else
			
			
            
			
			add_edge!
			(
			g
			,
			 
			i
			,
			 
			j
			)
			
        
			end
			
    
			end
			

			
    
			# TODO: check if callback state is read without being written
			

			
    
			# TODO: check for cyclical dependencies
			
    
			
			return
			 
			g
			

			end
			

			

			
			function
			 
			

	
			findconflicts
			(
			accesses1
			,
			 
			accesses2
			)
			
			
    
			
			conflicts
			 
			=
			 
			
			[
			]
			
    
			
			for
			
			 
			
			(
			i
			,
			 
			j
			)
			 
			in
			 
			
			
			Iterators
			.
			
			product
			(
			
			1
			:
			
			length
			(
			accesses1
			)
			,
			 
			
			1
			:
			
			length
			(
			accesses2
			)
			)
			
			
        
			
			if
			
			 
			i
			 
			!=
			 
			j
			
			
            
			
			for
			
			 
			
			(
			a1
			,
			 
			a2
			)
			 
			in
			 
			
			
			Iterators
			.
			
			product
			(
			
			accesses1
			[
			i
			]
			,
			 
			
			accesses2
			[
			j
			]
			)
			
			
                
			
			if
			 
			

	
			hasconflict
			(
			a1
			,
			 
			a2
			)
			
			
                    
			
			push!
			(
			conflicts
			,
			 
			
			(
			i
			,
			 
			j
			,
			 
			a1
			,
			 
			a2
			)
			)
			
                
			end
			
            
			end
			
        
			end
			
    
			end
			
    
			
			return
			 
			conflicts
			

			end
			

			

			

			
			
			
			"""
			

			    edgesrunafter(callbacks)

			

			Return a vector of `Edge`s representing dependencies

			defined by `runafter`.

			"""
			

			
			function
			 
			

	
			edgesrunafter
			(
			callbacks
			)
			
			
    
			
			edges
			 
			=
			 
			
			
			Edge
			{
			Int
			}
			[
			]
			
    
			
			for
			
			 
			
			(
			i
			,
			 
			cb
			)
			 
			in
			 
			
			enumerate
			(
			callbacks
			)
			
			
        
			
			Ts
			 
			=
			 
			

	
			runafter
			(
			cb
			)
			
        
			
			for
			
			 
			
			(
			j
			,
			 
			othercb
			)
			 
			in
			 
			
			enumerate
			(
			callbacks
			)
			
			
            
			
			if
			 
			
			any
			(
			
			[
			
			
			othercb
			 
			isa
			 
			T
			 
			for
			
			 
			T
			 
			in
			  
			Ts
			]
			)
			
			
                
			
			push!
			(
			edges
			,
			 
			
			Edge
			(
			j
			,
			 
			i
			)
			)
			
            
			end
			
        
			end
			
    
			end
			
    
			
			return
			 
			edges
			

			end
			

			

			

			

			

			
			
			
			"""
			

			    accesses()

			

			Enumerate all valid state accesses of permissions of kind `perm`.

			

			`accesses((x = Read(),), Read()) === [(:x,)]`

			`accesses((x = Read(),), Write()) === []`

			

			"""
			

			
			function
			 
			

	
			accesses
			(
			
			permissions
			::
			NamedTuple
			,
			 
			
			
			P
			::
			
			Type
			{
			
			<:

	
			Permission
			}
			 
			=
			 

	
			Permission
			,
			 
			
			a
			 
			=
			 
			
			[
			]
			,
			 
			
			prefix
			 
			=
			 
			
			(
			)
			)
			
			
    
			
			for
			
			 
			
			(
			field
			,
			 
			perm
			)
			 
			in
			 
			
			zip
			(
			
			keys
			(
			permissions
			)
			,
			 
			permissions
			)
			
			
        
			
			if
			
			 
			perm
			 
			isa
			 
			P
			
			
            
			
			push!
			(
			a
			,
			 
			
			(
			
			prefix
			...
			,
			 
			field
			)
			)
			
        
			
			elseif
			
			 
			perm
			 
			isa
			 
			NamedTuple
			
			
            
			

	
			accesses
			(
			perm
			,
			 
			P
			,
			 
			a
			,
			 
			
			(
			
			prefix
			...
			,
			 
			field
			)
			)
			
        
			end
			
    
			end
			
    
			a
			

			end
			

			

			

			
			function
			 
			
			
			

	
			hasconflict
			(
			
			access1
			::
			T1
			,
			 
			
			access2
			::
			T2
			)
			::
			Bool
			 
			where
			 
			{
			
			T1
			<:
			Tuple
			,
			 
			
			T2
			<:
			Tuple
			}
			
			
    
			
			
			l1
			,
			 
			l2
			 
			=
			
			 
			
			length
			(
			access1
			)
			,
			 
			
			length
			(
			access2
			)
			
    
			
			
			l1
			 
			>
			 
			0
			 
			||
			 
			
			return
			 
			false
			
    
			
			
			l2
			 
			>
			 
			0
			 
			||
			 
			
			return
			 
			false
			
    
			
			if
			
			 
			
			access1
			[
			1
			]
			 
			!=
			 
			
			access2
			[
			1
			]
			
			
        
			
			return
			 
			false
			
    
			
			elseif
			 
			(
			
			
			l1
			 
			==
			 
			1
			 
			||
			
			 
			l2
			 
			==
			 
			1
			)
			
			
        
			
			return
			 
			true
			
    
			else
			
			
        
			
			return
			 
			

	
			hasconflict
			(
			
			access1
			[
			
			2
			:
			end
			]
			,
			 
			
			access2
			[
			
			2
			:
			end
			]
			)
			
    
			end
			

			end
			

			

			

			

			
			function
			 
			

	
			errorwriteconflict
			(
			cb1
			,
			 
			cb2
			,
			 
			access1
			,
			 
			access2
			
			;
			 
			
			resolvable
			 
			=
			 
			true
			)
			
			
    
			
			msg
			 
			=
			 
			
			"""
			

			    
			Write conflict detected between 
			$
			cb1
			 and 
			$
			(
			cb2
			)
			!

			

			    
			- 
			$
			cb1
			 writes to 
			$
			(
			

	
			formataccess
			(
			access1
			)
			)
			

			    
			- 
			$
			cb2
			 writes to 
			$
			(
			

	
			formataccess
			(
			access2
			)
			)
			

			    
			"""
			

			
    
			
			if
			 
			resolvable
			
			
        
			
			msg
			 
			*=
			 
			
			"""
			

			

			        
			To resolve this, implement:

			

			        
			`resolveconflict(::
			$
			(
			
			typeof
			(
			cb1
			)
			)
			, ::
			$
			(
			
			typeof
			(
			cb2
			)
			)
			)`

			

			        
			See also `resolveconflict` and `ConflictResolution`.

			        
			"""
			
    
			else
			
			
        
			
			msg
			 
			*=
			 
			
			"""
			

			

			        
			Both callbacks cannot be used together, please remove one.

			        
			"""
			
    
			end
			

			
    
			
			error
			(
			msg
			)
			

			end
			

			

			
			

	
			formataccess
			(
			access
			)
			 
			=
			 
			
			join
			(
			
			string
			.
			
			(
			access
			)
			,
			 
			'.'
			)
			

			

			

			
			
			
			"""
			

			    iterpairs(a)

			

			Iterators over the Cartesian product of `a` with itself,

			skipping any pairs `(a, b)` where `a == b`.

			"""
			

			
			

	
			iterpairs
			(
			
			a
			::
			AbstractArray
			)
			 
			=
			 
			
			filter
			(
			
			
			(
			
			(
			i
			,
			 
			j
			)
			,
			)
			 
			->
			
			 
			i
			 
			!=
			 
			j
			,
			 
			
			
			
			Iterators
			.
			
			product
			(
			
			1
			:
			9
			,
			 
			
			1
			:
			9
			)
			 
			|>
			 
			collect
			)