Transferring data across devices
Flux relies on the MLDataDevices.jl package to manage devices and transfer data across them. You don't have to explicitly use the package, as Flux re-exports the necessary functions and types.
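For instance, a typical workflow moves arrays (or whole models) to the best available device and back again. The snippet below is a minimal sketch assuming Flux re-exports gpu_device and cpu_device as described above.

using Flux

x = rand(Float32, 3, 4)
device = gpu_device()              # a GPU device if one is functional, otherwise CPUDevice()
x_dev = device(x)                  # move the array to the selected device
model = Dense(3 => 2) |> device    # models move the same way
x_cpu = cpu_device()(x_dev)        # and back to the CPU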
MLDataDevices.cpu_device — Function
cpu_device(eltype=missing) -> CPUDevice
Return a CPUDevice object which can be used to transfer data to CPU.
The eltype parameter controls element type conversion:
- missing / nothing (default): Preserves the original element type.
- Type{<:AbstractFloat}: Converts floating-point arrays to the specified type.
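As an illustration of the eltype behaviour, a minimal sketch:

using MLDataDevices

cdev = cpu_device()              # default: preserves the original element type
x32 = cdev(rand(Float32, 3))     # still Float32

cdev64 = cpu_device(Float64)     # converts floating-point arrays
x64 = cdev64(rand(Float32, 3))   # eltype(x64) == Float64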
MLDataDevices.default_device_rng — Function
default_device_rng(::AbstractDevice)
Returns the default RNG for the device. This can be used to directly generate parameters and states on the device using WeightInitializers.jl.
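A hedged sketch of generating values directly on the device; it assumes the returned RNG supports the in-place Random.rand! API, which is the case for the CPU and CUDA backends:

using MLDataDevices, Random

dev = gpu_device()                            # falls back to cpu_device() if no GPU is functional
rng = MLDataDevices.default_device_rng(dev)
W = dev(zeros(Float32, 4, 4))                 # array living on the device
rand!(rng, W)                                 # filled in place, without a host round-trip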
MLDataDevices.functional — Function
functional(x::AbstractDevice) -> Bool
functional(::Type{<:AbstractDevice}) -> Bool
Checks if the device is functional. This is used to determine if the device can be used for computation. Note that even if the backend is loaded (as checked via MLDataDevices.loaded), the device may not be functional.
Note that while this function is not exported, it is considered part of the public API.
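For example (a small sketch; the CPU result is guaranteed, the GPU result depends on your system):

using MLDataDevices

MLDataDevices.functional(CPUDevice)    # true: the CPU backend is always usable
MLDataDevices.functional(CUDADevice)   # true only if CUDA.jl is loaded and a working GPU is visible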
MLDataDevices.get_device — Function
get_device(x) -> dev::AbstractDevice | Exception | Nothing
If all arrays (on the leaves of the structure) are on the same device, we return that device. Otherwise, we throw an error. If the object is device agnostic, we return nothing.
Special Returned Values
- nothing – denotes that the object is device agnostic. For example, scalar, abstract range, etc.
- UnknownDevice() – denotes that the device type is unknown.
See also get_device_type for a faster alternative that can be used for dispatch based on device type.
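A brief sketch of the returned values, assuming plain CPU arrays:

using MLDataDevices

get_device(rand(3))                          # CPUDevice()
get_device((w = rand(2, 2), b = rand(2)))    # CPUDevice(), determined from the leaves
get_device(1:5)                              # nothing: ranges are device agnostic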
MLDataDevices.gpu_device — Function
gpu_device(
    eltype::Union{Missing, Nothing, Type{<:AbstractFloat}}=missing;
    kwargs...
) -> AbstractDevice
gpu_device(
    device_id::Union{Nothing, Integer}=nothing,
    eltype::Union{Missing, Nothing, Type{<:AbstractFloat}}=missing;
    force::Bool=false
) -> AbstractDevice
Selects GPU device based on the following criteria:
- If the gpu_backend preference is set and the backend is functional on the system, then that device is selected.
- Otherwise, an automatic selection algorithm is used. We go over possible device backends in the order specified by supported_gpu_backends() and select the first functional backend.
- If no GPU device is functional and force is false, then cpu_device() is invoked.
- If nothing works, an error is thrown.
Arguments
- device_id::Union{Nothing, Integer}: The device id to select. If nothing, then we return the last selected device or, if none was selected, we run the autoselection and choose the current device using CUDA.device() or AMDGPU.device() or similar. If Integer, then we select the device with the given id. Note that this is 1-indexed, in contrast to the 0-indexed CUDA.jl. For example, id = 4 corresponds to CUDA.device!(3).
- eltype::Union{Missing, Nothing, Type{<:AbstractFloat}}: The element type to use for the device.
  - missing (default): Device specific. For CUDADevice this calls CUDA.cu(x), for AMDGPUDevice this calls AMDGPU.roc(x), for MetalDevice this calls Metal.mtl(x), for oneAPIDevice this calls oneArray(x).
  - nothing: Preserves the original element type.
  - Type{<:AbstractFloat}: Converts floating-point arrays to the specified type.
device_id is only applicable for CUDA and AMDGPU backends. For Metal, oneAPI and CPU backends, device_id is ignored and a warning is printed.
gpu_device won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. This is to ensure that deep learning operations work correctly. Nonetheless, if cuDNN is not loaded you can still manually create a CUDADevice object and use it (e.g. dev = CUDADevice()).
Keyword Arguments
- force::Bool: If true, then an error is thrown if no functional GPU device is found.
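Putting the selection rules together, a hedged usage sketch (which backend is selected depends on the trigger packages loaded and functional on your system):

using MLDataDevices

dev = gpu_device()               # first functional backend, else CPUDevice()
x = dev(rand(Float32, 3, 4))     # element type handling follows `eltype` above

# dev2 = gpu_device(2)                   # second GPU (1-indexed); CUDA/AMDGPU only
# dev_strict = gpu_device(; force=true)  # throws if no functional GPU is found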
MLDataDevices.gpu_backend! — Function
gpu_backend!() = gpu_backend!("")
gpu_backend!(backend) = gpu_backend!(string(backend))
gpu_backend!(backend::AbstractGPUDevice)
gpu_backend!(backend::String)
Creates a LocalPreferences.toml file with the desired GPU backend.
If backend == "", then the gpu_backend preference is deleted. Otherwise, backend is validated to be one of the possible backends and the preference is set to backend.
If a new backend is successfully set, then the Julia session must be restarted for the change to take effect.
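A short sketch of setting and clearing the preference; the backend name must be one of the values returned by supported_gpu_backends():

using MLDataDevices

gpu_backend!("AMDGPU")   # writes the gpu_backend preference to LocalPreferences.toml
# ...restart Julia so gpu_device() picks up the new preference...
gpu_backend!()           # deletes the preference again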
MLDataDevices.get_device_type — Function
get_device_type(x) -> Type{<:AbstractDevice} | Exception | Type{Nothing}
Similar to get_device, but returns the type of the device instead of the device itself. This value is often a compile time constant and is recommended over get_device wherever you are defining dispatches based on the device type.
Special Returned Values
- Nothing – denotes that the object is device agnostic. For example, scalar, abstract range, etc.
- UnknownDevice – denotes that the device type is unknown.
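A minimal dispatch sketch; process is a hypothetical helper used only for illustration:

using MLDataDevices

process(x) = process(get_device_type(x), x)
process(::Type{CPUDevice}, x) = sum(x)                             # CPU-specific path
process(::Type{<:MLDataDevices.AbstractDevice}, x) = sum(abs2, x)  # any other device

process(rand(4))   # dispatches to the CPU method; the device type is resolved at compile time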
MLDataDevices.isleaf — Function
isleaf(x) -> Bool
Returns true if x is a leaf node in the data structure.
Defining MLDataDevices.isleaf(x::T) = true for custom types can be used to customize the data movement behavior when an object with nested structure containing the type is transferred to a device.
Adapt.adapt_structure(::AbstractDevice, x::T) will be called during data movement if isleaf(x::T) returns true.
If MLDataDevices.isleaf(x::T) is not defined, then it will fall back to Functors.isleaf(x).
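A hedged sketch of customizing movement for a nested type; BoundingBox is a made-up example type, not part of the package:

using MLDataDevices, Adapt

struct BoundingBox{V<:AbstractVector}
    mins::V
    maxs::V
end

# Treat the whole box as a single leaf during device transfer...
MLDataDevices.isleaf(::BoundingBox) = true

# ...and define how a device moves it (both fields transferred together).
Adapt.adapt_structure(dev::MLDataDevices.AbstractDevice, b::BoundingBox) =
    BoundingBox(dev(b.mins), dev(b.maxs))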
MLDataDevices.loaded — Function
loaded(x::AbstractDevice) -> Bool
loaded(::Type{<:AbstractDevice}) -> Bool
Checks if the trigger package for the device is loaded. Trigger packages are as follows:
- CUDA.jl and cuDNN.jl (or just LuxCUDA.jl) for NVIDIA CUDA support.
- AMDGPU.jl for AMD GPU ROCm support.
- Metal.jl for Apple Metal GPU support.
- oneAPI.jl for Intel oneAPI GPU support.
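For example, a sketch assuming the relevant trigger packages are installed in your environment:

using MLDataDevices

MLDataDevices.loaded(CUDADevice)   # false in a fresh session

using LuxCUDA                      # loads CUDA.jl and cuDNN.jl

MLDataDevices.loaded(CUDADevice)   # true once the trigger packages are loaded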
MLDataDevices.reset_gpu_device! — Function
reset_gpu_device!()
Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again.
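For instance, after loading a new backend package mid-session you may want the selection to be recomputed (a small sketch):

using MLDataDevices

dev = gpu_device()     # the selection is remembered after the first call
reset_gpu_device!()    # forget the previous selection
dev = gpu_device()     # automatic selection runs again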
MLDataDevices.set_device! — Function
set_device!(T::Type{<:AbstractDevice}, dev_or_id)
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
- T::Type{<:AbstractDevice}: The device type to set.
- dev_or_id: Can be the device from the corresponding package. For example, for CUDA it can be a CuDevice. If it is an integer, it is the device id to set. This is 1-indexed.
set_device!(T::Type{<:AbstractDevice}, ::Nothing, rank::Integer)
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
- T::Type{<:AbstractDevice}: The device type to set.
- rank::Integer: Local rank of the process. This is applicable for distributed training and must be 0-indexed.
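A hedged sketch for the CUDA backend; local_rank is a placeholder you would normally obtain from your distributed launcher:

using MLDataDevices, CUDA

MLDataDevices.set_device!(CUDADevice, 2)   # select the second CUDA GPU (1-indexed)

local_rank = 0                             # placeholder: 0-indexed local rank of this process
MLDataDevices.set_device!(CUDADevice, nothing, local_rank)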
MLDataDevices.supported_gpu_backends — Function
supported_gpu_backends() -> Tuple{String, ...}
Return a tuple of supported GPU backends.
MLDataDevices.DeviceIterator — Type
DeviceIterator(dev::AbstractDevice, iterator)
Create a DeviceIterator that iterates through the provided iterator via iterate. Upon each iteration, the current batch is copied to the device dev, and the previous iteration is marked as freeable from GPU memory (via unsafe_free!; a no-op for a CPU device).
The conversion follows the same semantics as dev(<item from iterator>).
The design inspiration was taken from CUDA.CuIterator and was generalized to work with other backends and more complex iterators (using Functors).
Calling dev(::MLUtils.DataLoader) will automatically convert the dataloader to use the same semantics as DeviceIterator. This is generally preferred over looping over the dataloader directly and transferring the data to the device.
Examples
The following was run on a computer with an NVIDIA GPU.
julia> using MLDataDevices, MLUtils
julia> X = rand(Float64, 3, 33);
julia> dataloader = DataLoader(X; batchsize=13, shuffle=false);
julia> for (i, x) in enumerate(dataloader)
@show i, summary(x)
end
(i, summary(x)) = (1, "3×13 Matrix{Float64}")
(i, summary(x)) = (2, "3×13 Matrix{Float64}")
(i, summary(x)) = (3, "3×7 Matrix{Float64}")
julia> for (i, x) in enumerate(CUDADevice()(dataloader))
@show i, summary(x)
end
(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}")