cudnn.jl

Flux/cuda/cudnn.jl is a source file in module Flux

			
			
			
			
			import
			
			 
			NNlibCUDA
			:
			
			 
			batchnorm
			,
			
			 
			∇batchnorm
			

			

			
			# GPU forward pass for `Flux.BatchNorm`, delegated to `NNlibCUDA.batchnorm`.
# Only CuArray inputs of rank 2, 4 or 5 with Float32/Float64 elements reach this method.
#
# Preconditions, checked with `@assert` (a violation raises AssertionError):
#   * `BN.affine` must be true,
#   * `BN.track_stats` must be true,
#   * the size of the second-to-last axis of `x` (taken as the channel axis) must equal
#     `length(BN.β)`.
#
# `cache` is passed through to `batchnorm` untouched. The `training` flag comes from
# `Flux._isactive(BN, x)`. The layer's activation `BN.λ` is broadcast over the
# normalized result.
function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
                              cache=nothing) where T<:Union{Float32, Float64}
    @assert BN.affine "BatchNorm: only affine=true supported on gpu"
    @assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"

    # The channel axis is the one just before the trailing batch axis.
    channel_axis = ndims(x) - 1
    @assert length(BN.β) == size(x, channel_axis) "BatchNorm: input has wrong number of channels"

    is_training = Flux._isactive(BN, x)
    normalized = batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
                           cache = cache,
                           alpha = 1,
                           beta = 0,
                           eps = BN.ϵ,
                           training = is_training)
    return BN.λ.(normalized)
end
			

			

			
			# Reverse-mode rule for `NNlibCUDA.batchnorm`.
#
# The forward value is exactly `batchnorm(...)` with all keywords forwarded. The
# pullback unthunks the incoming cotangent and passes it to `∇batchnorm` with the same
# positional arguments and keywords. The tuple `∇batchnorm` returns is splatted into
# the cotangent slots after the function slot. Presumably that tuple is the gradients
# for (g, b, x) — TODO confirm against NNlibCUDA.
# The function itself, `running_mean`, `running_var` and `momentum` all get
# `NoTangent()`.
function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
    out = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)

    function pullback_batchnorm(dy)
        param_grads = ∇batchnorm(g, b, x, unthunk(dy), running_mean, running_var, momentum; kw...)
        # Non-differentiable trailing arguments: running_mean, running_var, momentum.
        trailing = ntuple(_ -> NoTangent(), 3)
        return (NoTangent(), param_grads..., trailing...)
    end

    return out, pullback_batchnorm
end