functor.jl

Flux/functor.jl is a source file in the Flux module.

import Adapt: adapt, adapt_storage
using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: Functors, @functor, functor, fmap, isleaf
using SparseArrays: AbstractSparseArray

			

			
			
			
			"""
			

			    testmode!(model, [mode]) -> model

			

			Set a layer, or all layers in a model, to test mode.

			This disables the effect of [`Dropout`](@ref) and

			some other regularisation layers.

			

			If you manually set a model into test mode, you need to manually place

			it back into train mode during training phase, using [`trainmode!`](@ref).

			

			There is an optional second argument, which takes a symbol `:auto` to

			reset all layers back to the default automatic mode.

			

			# Example

			

			```jldoctest

			julia> d = Dropout(0.3)

			Dropout(0.3)

			

			julia> testmode!(d)   # dropout is now always disabled

			Dropout(0.3, active=false)

			

			julia> trainmode!(d)  # dropout is now always enabled

			Dropout(0.3, active=true)

			

			julia> testmode!(d, :auto)  # back to default

			Dropout(0.3)

			```

			"""
			

			
			

	
			testmode!
			(
			m
			)
			 
			=
			 
			

	
			testmode!
			(
			m
			,
			 
			true
			)
			

			

			
			
			
			"""
			

			    trainmode!(model) -> model

			

			Set a layer, or all layers in a model, to training mode.

			Opposite to [`testmode!`](@ref), see further details there.

			"""
			

			
			

	
			trainmode!
			(
			m
			)
			 
			=
			 
			

	
			testmode!
			(
			m
			,
			 
			false
			)
			

			
			

	
			trainmode!
			(
			m
			,
			 
			
			mode
			::
			Symbol
			)
			 
			=
			 
			

	
			testmode!
			(
			m
			,
			 
			mode
			)
			

			
			

	
			trainmode!
			(
			m
			,
			 
			
			::
			Nothing
			)
			 
			=
			 
			

	
			testmode!
			(
			m
			,
			 
			nothing
			)

why do we have so much API?


			
			
			
			
			
			"""
			

			    testmode!(model, inactive)

			

			This two-argument method is largely internal. It recurses into the `model`,

			and until a method like `testmode!(d::Dropout, inactive)` alters the activity of a layer.

			Custom layers can support manual `testmode!` / `trainmode!` switching

			by defining such a method.

			

			Possible values of  `inactive` are:

			- `true` for testing, i.e. `active=false`

			- `false` for training, same as [`trainmode!`](@ref)`(m)`

			- `:auto` or `nothing` for Flux to detect training automatically.

			

			!!! compat

			    This method may be removed in a future breaking change, to separate

			    the user-facing `testmode!` from the internal recursion.

			"""
			

			
			function
			 
			

	
			testmode!
			(
			m
			,
			 
			mode
			)
			
			
  
			
			inactive
			 
			=
			 
			
			if
			
			 
			mode
			 
			isa
			 
			Symbol
			
			
    
			
			
			mode
			 
			===
			 
			
			:
			auto
			 
			||
			 
			
			throw
			(
			
			ArgumentError
			(
			
			"
			testmode! accepts only the symbol :auto, got :
			$
			mode
			"
			)
			)
			
    
			nothing
			
  
			
			elseif
			
			 
			mode
			 
			isa
			 
			
			Union
			{
			Bool
			,
			Nothing
			}
			
			
    
			mode
			
  
			else
			
			
    
			
			throw
			(
			
			ArgumentError
			(
			
			"
			testmode! does not accept 
			$
			(
			
			repr
			(
			mode
			)
			)
			 as the 2nd argument
			"
			)
			)
			
  
			end
			
  
			
			foreach
			(
			
			x
			 
			->
			 
			

	
			testmode!
			(
			x
			,
			 
			inactive
			)
			,
			 
			
			trainable
			(
			m
			)
			)
			
  
			m
			

			end
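
As the docstring above notes, a custom layer can opt into manual mode switching by defining its own two-argument method (extending `Flux.testmode!` from outside the package). A minimal sketch, assuming a hypothetical `MyDropoutLike` layer with an `active` field; none of these names are part of Flux:

using Flux

mutable struct MyDropoutLike          # hypothetical layer, for illustration only
  p::Float64
  active::Union{Bool, Nothing}        # nothing = automatic, true/false = forced
end

# `inactive` follows the convention documented above:
# true => test mode (active = false), false => train mode, nothing => automatic.
function Flux.testmode!(m::MyDropoutLike, inactive::Union{Bool, Nothing})
  m.active = isnothing(inactive) ? nothing : !inactive
  return m
end

With this in place, `testmode!(model)` on a containing model reaches the layer through the recursion over `trainable(m)` shown above.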
			

			

			
function params!(p::Params, x, seen = IdSet())
  if x isa AbstractArray{<:Number} && Functors.isleaf(x)
    return push!(p, x)
  elseif x in seen
    nothing
  else
    push!(seen, x)
    for child in trainable(x)
      params!(p, child, seen)
    end
  end
end

"""
    params(model)
    params(layers...)

Given a model or specific layers from a model, create a `Params` object pointing to its trainable parameters.

This can be used with the `gradient` function, see the [training section of the manual](@ref man-training), or as input to the [`Flux.train!`](@ref Flux.train!) function.

The behaviour of `params` on custom types can be customized using [`Functors.@functor`](@ref) or [`Flux.trainable`](@ref).

# Examples
```jldoctest
julia> using Flux: params

julia> params(Chain(Dense(ones(2,3)), softmax))  # unpacks Flux models
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])

julia> bn = BatchNorm(2, relu)
BatchNorm(2, relu)  # 4 parameters, plus 4 non-trainable

julia> params(bn)  # only the trainable parameters
Params([Float32[0.0, 0.0], Float32[1.0, 1.0]])

julia> params([1, 2, 3], [4])  # one or more arrays of numbers
Params([[1, 2, 3], [4]])

julia> params([[1, 2, 3], [4]])  # unpacks array of arrays
Params([[1, 2, 3], [4]])

julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin))  # ignores scalars, unpacks NamedTuples
Params([[2 2], [3, 3, 3]])
```
"""
function params(m...)
  ps = Params()
  params!(ps, m)
  return ps
end
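
A brief usage sketch in the standard implicit-parameters style (the model, data, and loss here are illustrative, not from this file):

using Flux

model = Dense(3 => 2)
ps = Flux.params(model)                      # Params pointing at weight and bias

x, y = rand(Float32, 3, 5), rand(Float32, 2, 5)
gs = gradient(() -> Flux.Losses.mse(model(x), y), ps)
gs[model.weight]                             # gradient for one tracked array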

Allows caching of the parameters when `params` is called within `gradient()`, to fix #2040. Making `params(m...)` fully non-differentiable (https://github.com/FluxML/Flux.jl/pull/2054) speeds up implicit use but silently breaks explicit use. The definition below follows Zygote's handling of `params(m...)`; see https://github.com/FluxML/Zygote.jl/pull/1248.


			
			
			
			
			
Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing

			

			
			
struct FluxCUDAAdaptor end

adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
if VERSION >= v"1.7"
  adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
else
  adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
end
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
  error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")

TODO: figure out the correct design for OneElement


			
			
			
			
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))

struct FluxCPUAdaptor end

Define rules for handling structured arrays.


			
			
			
			
adapt_storage(to::FluxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
adapt_storage(to::FluxCPUAdaptor, x::AbstractRange) = x
adapt_storage(to::FluxCPUAdaptor, x::Zygote.FillArrays.AbstractFill) = x
adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
adapt_storage(to::FluxCPUAdaptor, x::Zygote.OneElement) = x
adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
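
For context, these methods plug into Adapt.jl's generic recursion: `adapt(to, x)` rebuilds wrappers via `adapt_structure` and eventually calls `adapt_storage(to, leaf)`, which is what the rules above and below overload; anything without a matching rule falls through to Adapt's identity fallback. A minimal, self-contained sketch with a toy adaptor (the `ToFloat32` type is hypothetical, for illustration only):

using Adapt

struct ToFloat32 end                                  # hypothetical adaptor

# Leaf rule: floating-point arrays are converted; everything else falls through
# to Adapt's default `adapt_storage(to, x) = x`.
Adapt.adapt_storage(::ToFloat32, x::AbstractArray{<:AbstractFloat}) = Float32.(x)

A = adapt(ToFloat32(), rand(Float64, 2, 2))           # Matrix{Float32}
R = adapt(ToFloat32(), 1:3)                           # untouched, no matching rule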
			

			

			
function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)))
end

The following rrules for `adapt` are here to avoid double-wrapping issues, as seen in https://github.com/FluxML/Flux.jl/pull/2117#discussion_r1027321801.


			
			
			
			
			
ChainRulesCore.rrule(::typeof(adapt), a::FluxCPUAdaptor, x::AnyCuArray) =
  adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCUDAAdaptor(), unthunk(Δ)))

ChainRulesCore.rrule(::typeof(adapt), a::FluxCPUAdaptor, x::AbstractArray) =
  adapt(a, x), Δ -> (NoTangent(), NoTangent(), Δ)

ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) =
  adapt(a, x), Δ -> (NoTangent(), NoTangent(), Δ)

ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) =
  adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ)))
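
The shape of these rules is the usual ChainRulesCore pattern for a device or storage move: the pullback sends the cotangent back through the opposite conversion, so gradients land where the primal argument lives. A generic sketch of the pattern, independent of CUDA (`to_f32` and `to_f64` are stand-ins, not Flux or CUDA functions):

using ChainRulesCore

to_f32(x::AbstractArray) = Float32.(x)   # stand-in for "move to device"
to_f64(x::AbstractArray) = Float64.(x)   # stand-in for the reverse move

function ChainRulesCore.rrule(::typeof(to_f32), x::AbstractArray)
  y = to_f32(x)
  to_f32_pullback(Δ) = (NoTangent(), to_f64(unthunk(Δ)))  # cotangent goes back the other way
  return y, to_f32_pullback
end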

CPU/GPU movement conveniences


			
			
			
			
			
			"""
			

			    cpu(m)

			

			Copies `m` onto the CPU, the opposite of [`gpu`](@ref).

			Recurses into structs marked [`@functor`](@ref).

			

			# Example

			```julia-repl

			julia> m_gpu = Dense(CUDA.randn(2, 5))

			Dense(5 => 2)       # 12 parameters

			

			julia> m_gpu.bias  # matches the given weight matrix

			2-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:

			 0.0

			 0.0

			

			julia> m = m_gpu |> cpu

			Dense(5 => 2)       # 12 parameters

			

			julia> m.bias

			2-element Vector{Float32}:

			 0.0

			 0.0

			```

			"""
			

			
			

	
			cpu
			(
			x
			)
			 
			=
			 
			
			fmap
			(
			
			x
			 
			->
			 
			
			adapt
			(
			

	
			FluxCPUAdaptor
			(
			)
			,
			 
			x
			)
			,
			 
			x
			,
			 
			
			exclude
			 
			=
			 
			_isleaf
			)
			

			

			
			
			_isbitsarray
			(
			
			::
			
			AbstractArray
			{
			
			<:
			Number
			}
			)
			 
			=
			 
			true
			

			
			
			
			_isbitsarray
			(
			
			::
			
			AbstractArray
			{
			T
			}
			)
			 
			where
			 
			T
			 
			=
			 
			
			isbitstype
			(
			T
			)
			

			
			
			_isbitsarray
			(
			x
			)
			 
			=
			 
			false
			

			

			
			
			_isleaf
			(
			
			::
			AbstractRNG
			)
			 
			=
			 
			true
			

			
			
			_isleaf
			(
			x
			)
			 
			=
			
			 
			
			_isbitsarray
			(
			x
			)
			 
			||
			 
			
			
			Functors
			.
			
			isleaf
			(
			x
			)
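
The `exclude` keyword is what keeps `fmap` from descending into numeric arrays element by element. A small sketch of the same mechanism on plain nested containers (nothing Flux-specific; the doubling function is just for illustration):

using Functors

nt = (layers = ([1.0, 2.0], (w = [3.0 4.0], name = "demo")),)

# Treat numeric arrays as leaves, just as `_isbitsarray` does above; `f` is then
# applied to whole arrays, and other leaves like the string pass through.
fmap(x -> 2x, nt; exclude = x -> x isa AbstractArray{<:Number})
# => (layers = ([2.0, 4.0], (w = [6.0 8.0], name = "demo")),)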
			

			

			
const GPU_BACKENDS = ("CUDA", "AMD", "Metal")
const GPU_BACKEND = @load_preference("gpu_backend", "CUDA")

function gpu_backend!(backend::String)
    if backend == GPU_BACKEND
        @info """
        GPU backend is already set to: $backend.
        No need to do anything else.
        """
        return
    end

    backend in GPU_BACKENDS || throw(ArgumentError("""
    Unsupported GPU backend: $backend.
    Supported backends are: $GPU_BACKENDS.
    """))

    @set_preferences!("gpu_backend" => backend)

    @info """
    New GPU backend set: $backend.
    Restart your Julia session for this change to take effect!
    """
end
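
A brief usage sketch: the preference is read once at load time into `GPU_BACKEND` (via `@load_preference` above, presumably from Preferences.jl), so a change only takes effect in the next Julia session.

using Flux

Flux.gpu_backend!("AMD")   # must be one of GPU_BACKENDS: "CUDA", "AMD", "Metal"
# prints an @info message asking you to restart Julia for the change to apply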
			

			

			
			
			
			"""
			

			    gpu(m)

			

			Copies `m` to the current GPU device (using current GPU backend), if one is available.

			If no GPU is available, it does nothing (but prints a warning the first time).

			

			On arrays, this calls CUDA's `cu`, which also changes arrays

			with Float64 elements to Float32 while copying them to the device (same for AMDGPU).

			To act on arrays within a struct, the struct type must be marked with [`@functor`](@ref).

			

			Use [`cpu`](@ref) to copy back to ordinary `Array`s.

			See also [`f32`](@ref) and [`f16`](@ref) to change element type only.

			

			See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/) 

			to help identify the current device.

			

			# Example

			```julia-repl

			julia> m = Dense(rand(2, 3))  # constructed with Float64 weight matrix

			Dense(3 => 2)       # 8 parameters

			

			julia> typeof(m.weight)

			Matrix{Float64} (alias for Array{Float64, 2})

			

			julia> m_gpu = gpu(m)  # can equivalently be written m_gpu = m |> gpu

			Dense(3 => 2)       # 8 parameters

			

			julia> typeof(m_gpu.weight)

			CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}

			```

			"""
			

			
			function
			 
			

	
			gpu
			(
			x
			)
			
			
    
			
			@
			static
			 
			
			if
			
			 

	
			GPU_BACKEND
			 
			==
			 
			
			"
			CUDA
			"
			
			
        
			

	
			gpu
			(
			

	
			FluxCUDAAdaptor
			(
			)
			,
			 
			x
			)
			
    
			
			elseif
			
			 

	
			GPU_BACKEND
			 
			==
			 
			
			"
			AMD
			"
			
			
        
			

	
			gpu
			(
			

	
			FluxAMDAdaptor
			(
			)
			,
			 
			x
			)
			
    
			
			elseif
			
			 

	
			GPU_BACKEND
			 
			==
			 
			
			"
			Metal
			"
			
			
        
			

	
			gpu
			(
			

	
			FluxMetalAdaptor
			(
			)
			,
			 
			x
			)
			
    
			else
			
			
        
			
			error
			(
			
			"""
			

			        
			Unsupported GPU backend: 
			$

	
			GPU_BACKEND
			.

			        
			Supported backends are: 
			$

	
			GPU_BACKENDS
			.

			        
			"""
			)
			
    
			end
			

			end
			

			

			
function gpu(::FluxCUDAAdaptor, x)
  check_use_cuda()
  use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
end

function check_use_cuda()
  if use_cuda[] === nothing
    use_cuda[] = CUDA.functional()
    if use_cuda[] && !cuDNN.has_cudnn()
      @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available." maxlog=1
    end
    if !(use_cuda[])
      @info """The GPU function is being called but the GPU is not accessible.
               Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1
    end
  end
end

ChainRulesCore.@non_differentiable check_use_cuda()

Precision


			
			
			
			
struct FluxEltypeAdaptor{T} end

Adapt.adapt_storage(::FluxEltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}) where {T<:AbstractFloat} =
  convert(AbstractArray{T}, x)
Adapt.adapt_storage(::FluxEltypeAdaptor{T}, x::AbstractArray{<:Complex{<:AbstractFloat}}) where {T<:AbstractFloat} =
  convert(AbstractArray{Complex{T}}, x)

_paramtype(::Type{T}, m) where T = fmap(adapt(FluxEltypeAdaptor{T}()), m)

Fast path for arrays.


			
			
			
			
			
_paramtype(::Type{T}, x::AbstractArray{<:AbstractFloat}) where {T<:AbstractFloat} =
  convert(AbstractArray{T}, x)
_paramtype(::Type{T}, x::AbstractArray{<:Complex{<:AbstractFloat}}) where {T<:AbstractFloat} =
  convert(AbstractArray{Complex{T}}, x)

"""
    f32(m)

Converts the `eltype` of the model's *floating point* parameters to `Float32` (which is Flux's default).
Recurses into structs marked with [`@functor`](@ref).

See also [`f64`](@ref) and [`f16`](@ref).
"""
f32(m) = _paramtype(Float32, m)

"""
    f64(m)

Converts the `eltype` of the model's *floating point* parameters to `Float64`.
Recurses into structs marked with [`@functor`](@ref).

See also [`f32`](@ref) and [`f16`](@ref).
"""
f64(m) = _paramtype(Float64, m)

"""
    f16(m)

Converts the `eltype` of the model's *floating point* parameters to `Float16`.
Recurses into structs marked with [`@functor`](@ref).

Support for `Float16` is limited on many CPUs. Julia may
convert to `Float32` for each operation, which is slow.

See also [`f32`](@ref) and [`f64`](@ref).

# Example
```jldoctest
julia> m = Chain(Dense(784, 2048, relu), Dense(2048, 10))  # all Float32
Chain(
  Dense(784 => 2048, relu),             # 1_607_680 parameters
  Dense(2048 => 10),                    # 20_490 parameters
)                   # Total: 4 arrays, 1_628_170 parameters, 6.211 MiB.

julia> m |> f16  # takes half the memory
Chain(
  Dense(784 => 2048, relu),             # 1_607_680 parameters
  Dense(2048 => 10),                    # 20_490 parameters
)                   # Total: 4 arrays, 1_628_170 parameters, 3.106 MiB.
```
"""
f16(m) = _paramtype(Float16, m)
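
A small sketch of what the precision helpers do and do not touch: only floating-point (and complex floating-point) arrays are converted, while integer arrays and other leaves pass through unchanged. The NamedTuple here just stands in for a model:

using Flux

m = (w = rand(Float64, 2, 2), idx = [1, 2, 3])  # stand-in for a model
m32 = Flux.f32(m)

eltype(m32.w)    # Float32
eltype(m32.idx)  # Int, untouched: only floating-point parameters are converted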

Functors for certain Julia data structures


			
			
			
@functor Cholesky
trainable(c::Cholesky) = ()
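
This is the same mechanism user code can apply to its own types: `@functor` lets the `fmap`-based functions (`gpu`, `cpu`, `f32`, ...) recurse into the struct, while overriding `trainable` restricts what `params` and the optimisers see. A hedged sketch with a made-up type; the field names and the NamedTuple form of `trainable` are illustrative:

using Flux
import Flux: trainable

struct Affine                           # hypothetical layer, not part of Flux
  W::Matrix{Float32}
  b::Vector{Float32}
end

Flux.@functor Affine                    # all fields become functor children
trainable(a::Affine) = (; W = a.W)      # ...but only W is reported as trainable

a = Affine(rand(Float32, 2, 3), zeros(Float32, 2))
Flux.params(a)                          # Params containing only a.W; gpu/cpu/f32 still move b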

AMDGPU extension.


			
			
			
			
struct FluxAMDAdaptor end

const AMDGPU_LOADED = Ref{Bool}(false)

function gpu(::FluxAMDAdaptor, x)
    if AMDGPU_LOADED[]
        return _amd(x)
    else
        @info """
        The AMDGPU functionality is being called via `Flux.amd` but
        `AMDGPU` must be loaded to access it.
        """ maxlog=1
    end
end

function _amd end
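
The pattern here (and for Metal below) is a stub function plus a loaded-flag: `_amd` has no methods until code that depends on AMDGPU defines one and flips `AMDGPU_LOADED[]`. A self-contained sketch of the idea, using stand-in module names rather than the real Flux extension:

module FakeFlux                      # stand-in for the main package
  const BACKEND_LOADED = Ref{Bool}(false)
  function _backend end              # stub: zero methods until an extension adds one

  function move(x)
    BACKEND_LOADED[] || return x     # graceful no-op when the backend is absent
    return _backend(x)
  end
end

module FakeExt                       # stand-in for the backend-specific extension
  import ..FakeFlux
  FakeFlux._backend(x) = map(Float32, x)   # pretend "device" conversion
  FakeFlux.BACKEND_LOADED[] = true
end

FakeFlux.move([1.0, 2.0])            # now dispatches through the "extension"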

Metal extension.


			
			
			
			
struct FluxMetalAdaptor end

const METAL_LOADED = Ref{Bool}(false)

function gpu(::FluxMetalAdaptor, x)
    if METAL_LOADED[]
        return _metal(x)
    else
        @info """
        The Metal functionality is being called but
        `Metal.jl` must be loaded to access it.
        """ maxlog=1
    end
end

function _metal end
			

			

			

			
			
			
			"""
			

			    gpu(data::DataLoader)

			

			Transforms a given `DataLoader` to apply `gpu` to each batch of data,

			when iterated over. (If no GPU is available, this does nothing.)

			

			# Example

			

			```julia-repl

			julia> dl = Flux.DataLoader((x = ones(2,10), y='a':'j'), batchsize=3)

			4-element DataLoader(::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}, batchsize=3)

			  with first element:

			  (; x = 2×3 Matrix{Float64}, y = 3-element StepRange{Char, Int64})

			

			julia> first(dl)

			(x = [1.0 1.0 1.0; 1.0 1.0 1.0], y = 'a':1:'c')

			

			julia> c_dl = gpu(dl)

			4-element DataLoader(::MLUtils.MappedData{:auto, typeof(gpu), NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}}, batchsize=3)

			  with first element:

			  (; x = 2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element StepRange{Char, Int64})

			

			julia> first(c_dl).x

			2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:

			 1.0  1.0  1.0

			 1.0  1.0  1.0

			```

			

			For large datasets, this is preferred over moving all the data to

			the GPU before creating the `DataLoader`, like this:

			

			```julia-repl

			julia> Flux.DataLoader((x = ones(2,10), y=2:11) |> gpu, batchsize=3)

			4-element DataLoader(::NamedTuple{(:x, :y), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, UnitRange{Int64}}}, batchsize=3)

			  with first element:

			  (; x = 2×3 CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element UnitRange{Int64})

			```

			

			!!! warning

			    This only works if `gpu` is applied directly to the `DataLoader`.

			    While `gpu` acts recursively on Flux models and many basic Julia structs,

			    it will not work on (say) a tuple of `DataLoader`s.

			"""
			

			
			function
			 
			

	
			gpu
			(
			
			d
			::
			

	
			MLUtils
			.
			

	
			DataLoader
			)
			
			
  
			
			

	
			MLUtils
			.
			

	
			DataLoader
			(
			
			

	
			MLUtils
			.
			

	
			mapobs
			(

	
			gpu
			,
			 
			
			d
			.
			
			data
			)
			,
			
    
			
			d
			.
			

	
			batchsize
			,
			
    
			
			d
			.
			
			buffer
			,
			
    
			
			d
			.
			
			partial
			,
			
    
			
			d
			.
			
			shuffle
			,
			
    
			
			d
			.
			
			parallel
			,
			
    
			
			d
			.
			
			collate
			,
			
    
			
			d
			.
			
			rng
			,
			
  
			)
			

			end