From 99c33d011d63174f50cea37c3eede002958be6d3 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Wed, 11 Jun 2025 02:51:51 +0900
Subject: [PATCH] rebase + nit (#2260)

Co-authored-by: Awni Hannun
---
 mlx/backend/cuda/CMakeLists.txt               |  15 ++-
 mlx/backend/cuda/copy.cpp                     |  26 -----
 mlx/backend/cuda/copy.cu                      |  89 +++++++++++++++
 mlx/backend/cuda/copy/copy.cuh                |  71 ++++++++++++
 mlx/backend/cuda/copy/copy_contiguous.cu      |  56 ++++++++++
 mlx/backend/cuda/copy/copy_general.cu         |  95 ++++++++++++++++
 mlx/backend/cuda/copy/copy_general_dynamic.cu | 105 ++++++++++++++++++
 mlx/backend/cuda/copy/copy_general_input.cu   |  88 +++++++++++++++
 mlx/backend/cuda/kernels/cast_op.cuh          |  59 ++++++++++
 mlx/backend/cuda/slicing.cpp                  |  28 ++++-
 10 files changed, 604 insertions(+), 28 deletions(-)
 delete mode 100644 mlx/backend/cuda/copy.cpp
 create mode 100644 mlx/backend/cuda/copy.cu
 create mode 100644 mlx/backend/cuda/copy/copy.cuh
 create mode 100644 mlx/backend/cuda/copy/copy_contiguous.cu
 create mode 100644 mlx/backend/cuda/copy/copy_general.cu
 create mode 100644 mlx/backend/cuda/copy/copy_general_dynamic.cu
 create mode 100644 mlx/backend/cuda/copy/copy_general_input.cu
 create mode 100644 mlx/backend/cuda/kernels/cast_op.cuh

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 6ca176ceb..7ffbcb2d3 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -7,7 +7,11 @@ target_sources(
   mlx
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
@@ -28,6 +32,15 @@ target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
 target_compile_options(mlx
                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
 
+# CUDA 12.8 emits warning #20280-D for the copy kernels, which is a false
+# positive. Explicitly pass this flag to suppress the warning; setting it to
+# true would also be safe, but then the warning would not be suppressed.
+if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
+  target_compile_options(
+    mlx
+    PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
+endif()
+
 # Compute capability 7 is required for synchronization between CPU/GPU with
 # managed memory. TODO: Add more architectures for potential performance gain.
 set(MLX_CUDA_ARCHITECTURES
diff --git a/mlx/backend/cuda/copy.cpp b/mlx/backend/cuda/copy.cpp
deleted file mode 100644
index d0413d989..000000000
--- a/mlx/backend/cuda/copy.cpp
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/gpu/copy.h"
-
-namespace mlx::core {
-
-void copy_gpu_inplace(
-    const array& in,
-    array& out,
-    const Shape& data_shape,
-    const Strides& strides_in_pre,
-    const Strides& strides_out_pre,
-    int64_t inp_offset,
-    int64_t out_offset,
-    CopyType ctype,
-    const Stream& s,
-    const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
-    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
-  throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
-}
-
-void fill_gpu(const array& val, array& out, const Stream& s) {
-  throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
-}
-
-} // namespace mlx::core
diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu
new file mode 100644
index 000000000..8649e1bf9
--- /dev/null
+++ b/mlx/backend/cuda/copy.cu
@@ -0,0 +1,89 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/copy/copy.cuh"
+
+namespace mlx::core {
+
+void copy_gpu_inplace(
+    const array& in_,
+    array& out,
+    const Shape& shape,
+    const Strides& strides_in,
+    const Strides& strides_out,
+    int64_t offset_in,
+    int64_t offset_out,
+    CopyType ctype,
+    const Stream& s,
+    const std::optional<array>& dynamic_offset_in,
+    const std::optional<array>& dynamic_offset_out) {
+  if (out.size() == 0) {
+    return;
+  }
+  const array& in = in_.data_shared_ptr() ? in_ : out;
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+
+  if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
+    copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
+    return;
+  }
+
+  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
+    auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
+        shape, std::vector{strides_in, strides_out}, INT32_MAX);
+    if (ctype == CopyType::General) {
+      copy_general_input(
+          encoder,
+          ctype,
+          in,
+          out,
+          offset_in,
+          offset_out,
+          shape_collapsed,
+          strides_vec[0]);
+    } else {
+      if (dynamic_offset_in || dynamic_offset_out) {
+        copy_general_dynamic(
+            encoder,
+            ctype,
+            in,
+            out,
+            offset_in,
+            offset_out,
+            shape_collapsed,
+            strides_vec[0],
+            strides_vec[1],
+            dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
+            dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
+      } else {
+        copy_general(
+            encoder,
+            ctype,
+            in,
+            out,
+            offset_in,
+            offset_out,
+            shape_collapsed,
+            strides_vec[0],
+            strides_vec[1]);
+      }
+    }
+    return;
+  }
+}
+
+void fill_gpu(const array& in, array& out, const Stream& s) {
+  if (out.size() == 0) {
+    return;
+  }
+  out.set_data(allocator::malloc(out.nbytes()));
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh
new file mode 100644
index 000000000..dd1d09d30
--- /dev/null
+++ b/mlx/backend/cuda/copy/copy.cuh
@@ -0,0 +1,71 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/kernels/cast_op.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+
+namespace mlx::core {
+
+#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
+  MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
+    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
+      using InType = cuda_type_t<CTYPE_IN>; \
+      using OutType = cuda_type_t<CTYPE_OUT>; \
+      if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
+        __VA_ARGS__; \
+      } else { \
+        throw std::runtime_error(fmt::format( \
+            "Can not copy data from dtype {} to {}.", \
+            dtype_to_string(in.dtype()), \
+            dtype_to_string(out.dtype()))); \
+      } \
+    }); \
+  })
+
+void copy_contiguous(
+    cu::CommandEncoder& encoder,
+    CopyType ctype,
+    const array& in,
+    array& out,
+    int64_t offset_in,
+    int64_t offset_out);
+
+void copy_general(
+    cu::CommandEncoder& encoder,
+    CopyType ctype,
+    const array& in,
+    array& out,
+    int64_t offset_in,
+    int64_t offset_out,
+    const Shape& shape,
+    const Strides& strides_in,
+    const Strides& strides_out);
+
+void copy_general_dynamic(
+    cu::CommandEncoder& encoder,
+    CopyType ctype,
+    const array& in,
+    array& out,
+    int64_t offset_in,
+    int64_t offset_out,
+    const Shape& shape,
+    const Strides& strides_in,
+    const Strides& strides_out,
+    const array& dynamic_offset_in,
+    const array& dynamic_offset_out);
+
+void copy_general_input(
+    cu::CommandEncoder& encoder,
+    CopyType ctype,
+    const array& in,
+    array& out,
+    int64_t offset_in,
+    int64_t offset_out,
+    const Shape& shape,
+    const Strides& strides_in);
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu
new file mode 100644
index 000000000..fa79f0604
--- /dev/null
+++ b/mlx/backend/cuda/copy/copy_contiguous.cu
@@ -0,0 +1,56 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/copy/copy.cuh"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename In, typename Out, typename IdxT>
+__global__ void copy_s(const In* in, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = CastOp<In, Out>{}(in[0]);
+  }
+}
+
+template <typename In, typename Out, typename IdxT>
+__global__ void copy_v(const In* in, Out* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = CastOp<In, Out>{}(in[index]);
+  }
+}
+
+} // namespace cu
+
+void copy_contiguous(
+    cu::CommandEncoder& encoder,
+    CopyType ctype,
+    const array& in,
+    array& out,
+    int64_t in_offset,
+    int64_t out_offset) {
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
+      MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
+        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+        auto kernel = cu::copy_s<InType, OutType, IdxT>;
+        if (ctype == CopyType::Vector) {
+          kernel = cu::copy_v<InType, OutType, IdxT>;
+        }
+        auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
+        kernel<<<num_blocks, block_dims, 0, stream>>>(
+            in.data<InType>() + in_offset,
+            out.data<OutType>() + out_offset,
+            out.data_size());
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu
new file mode 100644
index 000000000..3c5b3bbb3
--- /dev/null
+++ b/mlx/backend/cuda/copy/copy_general.cu
@@ -0,0 +1,95 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/copy/copy.cuh"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename In, typename Out, typename IdxT, int NDIM>
+__global__ void copy_gg_nd(
+    const In* in,
+    Out* out,
+    IdxT size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
+        index, shape.data(), strides_in.data(), strides_out.data());
+    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
+  }
+}
+
+template <typename In, typename Out, typename IdxT>
+__global__ void copy_gg(
+    const In* in,
+    Out* out,
+    IdxT size,
+    const __grid_constant__ Shape shape,
+    const __grid_constant__ Strides strides_in,
+    const __grid_constant__ Strides strides_out,
+    int ndim) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [idx_in, idx_out] = elem_to_loc_4d(
+        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
+    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
+  }
+}
+
+} // namespace cu
+
+void copy_general(
+    cu::CommandEncoder& encoder,
+    CopyType ctype,
+    const array& in,
+    array& out,
+    int64_t offset_in,
+    int64_t offset_out,
+    const Shape& shape,
+    const Strides& strides_in,
+    const Strides& strides_out) {
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
+      const InType* in_ptr = in.data<InType>() + offset_in;
+      OutType* out_ptr = out.data<OutType>() + offset_out;
+      bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
+      MLX_SWITCH_BOOL(large, LARGE, {
+        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+        int ndim = shape.size();
+        if (ndim <= 3) {
+          MLX_SWITCH_1_2_3(ndim, NDIM, {
+            auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
+            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+            kernel<<<num_blocks, block_dims, 0, stream>>>(
+                in_ptr,
+                out_ptr,
+                out.data_size(),
+                const_param<NDIM>(shape),
+                const_param<NDIM>(strides_in),
+                const_param<NDIM>(strides_out));
+          });
+        } else { // ndim >= 4
+          auto kernel = cu::copy_gg<InType, OutType, IdxT>;
+          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+          kernel<<<num_blocks, block_dims, 0, stream>>>(
+              in_ptr,
+              out_ptr,
+              out.data_size(),
+              const_param(shape),
+              const_param(strides_in),
+              const_param(strides_out),
+              ndim);
+        }
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu
new file mode 100644
index 000000000..b9774662a
--- /dev/null
+++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu
@@ -0,0 +1,105 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/copy/copy.cuh"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename In, typename Out, typename IdxT, int NDIM>
+__global__ void copy_gg_dynamic_nd(
+    const In* in,
+    Out* out,
+    IdxT size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,
+    const int64_t* offset_in,
+    const int64_t* offset_out) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
+        index, shape.data(), strides_in.data(), strides_out.data());
+    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
+  }
+}
+
+template <typename In, typename Out, typename IdxT>
+__global__ void copy_gg_dynamic(
+    const In* in,
+    Out* out,
+    IdxT size,
+    const __grid_constant__ Shape shape,
+    const __grid_constant__ Strides strides_in,
+    const __grid_constant__ Strides strides_out,
+    int ndim,
+    const int64_t* offset_in,
+    const int64_t* offset_out) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [idx_in, idx_out] = elem_to_loc_4d(
+        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
+    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
+  }
+}
+
+} // namespace cu
+
+void copy_general_dynamic(
+    cu::CommandEncoder& encoder,
+    CopyType ctype,
+    const array& in,
+    array& out,
+    int64_t offset_in,
+    int64_t offset_out,
+    const Shape& shape,
+    const Strides& strides_in,
+    const Strides& strides_out,
+    const array& dynamic_offset_in,
+    const array& dynamic_offset_out) {
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
+      const InType* in_ptr = in.data<InType>() + offset_in;
+      OutType* out_ptr = out.data<OutType>() + offset_out;
+      bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
+      MLX_SWITCH_BOOL(large, LARGE, {
+        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+        int ndim = shape.size();
+        if (ndim <= 3) {
+          MLX_SWITCH_1_2_3(ndim, NDIM, {
+            auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
+            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+            kernel<<<num_blocks, block_dims, 0, stream>>>(
+                in_ptr,
+                out_ptr,
+                out.data_size(),
+                const_param<NDIM>(shape),
+                const_param<NDIM>(strides_in),
+                const_param<NDIM>(strides_out),
+                dynamic_offset_in.data<int64_t>(),
+                dynamic_offset_out.data<int64_t>());
+          });
+        } else { // ndim >= 4
+          auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
+          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+          kernel<<<num_blocks, block_dims, 0, stream>>>(
+              in_ptr,
+              out_ptr,
+              out.data_size(),
+              const_param(shape),
+              const_param(strides_in),
+              const_param(strides_out),
+              ndim,
+              dynamic_offset_in.data<int64_t>(),
+              dynamic_offset_out.data<int64_t>());
+        }
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu
new file mode 100644
index 000000000..4f2784927
--- /dev/null
+++ b/mlx/backend/cuda/copy/copy_general_input.cu
@@ -0,0 +1,88 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/copy/copy.cuh"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename In, typename Out, typename IdxT, int NDIM>
+__global__ void copy_g_nd(
+    const In* in,
+    Out* out,
+    IdxT size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
+    out[index] = CastOp<In, Out>{}(in[idx_in]);
+  }
+}
+
+template <typename In, typename Out, typename IdxT>
+__global__ void copy_g(
+    const In* in,
+    Out* out,
+    IdxT size,
+    const __grid_constant__ Shape shape,
+    const __grid_constant__ Strides strides_in,
+    int ndim) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
+    out[index] = CastOp<In, Out>{}(in[idx_in]);
+  }
+}
+
+} // namespace cu
+
+void copy_general_input(
+    cu::CommandEncoder& encoder,
+    CopyType ctype,
+    const array& in,
+    array& out,
+    int64_t offset_in,
+    int64_t offset_out,
+    const Shape& shape,
+    const Strides& strides_in) {
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
+      const InType* in_ptr = in.data<InType>() + offset_in;
+      OutType* out_ptr = out.data<OutType>() + offset_out;
+      bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
+      MLX_SWITCH_BOOL(large, LARGE, {
+        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+        int ndim = shape.size();
+        if (ndim <= 3) {
+          MLX_SWITCH_1_2_3(ndim, NDIM, {
+            auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
+            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+            kernel<<<num_blocks, block_dims, 0, stream>>>(
+                in_ptr,
+                out_ptr,
+                out.data_size(),
+                const_param<NDIM>(shape),
+                const_param<NDIM>(strides_in));
+          });
+        } else { // ndim >= 4
+          auto kernel = cu::copy_g<InType, OutType, IdxT>;
+          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+          kernel<<<num_blocks, block_dims, 0, stream>>>(
+              in_ptr,
+              out_ptr,
+              out.data_size(),
+              const_param(shape),
+              const_param(strides_in),
+              ndim);
+        }
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/kernels/cast_op.cuh b/mlx/backend/cuda/kernels/cast_op.cuh
new file mode 100644
index 000000000..30b44d46f
--- /dev/null
+++ b/mlx/backend/cuda/kernels/cast_op.cuh
@@ -0,0 +1,59 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cuComplex.h>
+#include <thrust/iterator/transform_iterator.h>
+
+namespace mlx::core::cu {
+
+// An op that does static_cast, with custom conversions for some types.
+template <typename SrcT, typename DstT, typename = void>
+struct CastOp {
+  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;
+
+  __device__ DstT operator()(SrcT x) {
+    return static_cast<DstT>(x);
+  }
+};
+
+// Converting a complex number to real number discards the imaginary part.
+template <typename DstT>
+struct CastOp<
+    cuComplex,
+    DstT,
+    cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
+  static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
+
+  __device__ DstT operator()(cuComplex x) {
+    static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
+    return static_cast<DstT>(cuCrealf(x));
+  }
+};
+
+// Allow converting a real number to complex number.
+template <typename SrcT>
+struct CastOp<
+    SrcT,
+    cuComplex,
+    cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
+  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
+
+  __device__ cuComplex operator()(SrcT x) {
+    static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
+    return cuComplex{static_cast<float>(x), 0};
+  }
+};
+
+// Return an iterator that casts the value to DstT using CastOp.
+template <typename DstT, typename Iterator>
+__host__ __device__ auto make_cast_iterator(Iterator it) {
+  using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
+  if constexpr (std::is_same_v<SrcT, DstT>) {
+    return it;
+  } else {
+    return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
+  }
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp
index bfa742c74..af67fbbdd 100644
--- a/mlx/backend/cuda/slicing.cpp
+++ b/mlx/backend/cuda/slicing.cpp
@@ -1,7 +1,11 @@
 // Copyright © 2025 Apple Inc.
 
+#include "mlx/backend/common/slicing.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
 
+#include <numeric>
+
 namespace mlx::core {
 
 void concatenate_gpu(
@@ -9,7 +13,29 @@ void concatenate_gpu(
     array& out,
     int axis,
     const Stream& s) {
-  throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
+  std::vector<int> sizes;
+  sizes.push_back(0);
+  for (auto& p : inputs) {
+    sizes.push_back(p.shape(axis));
+  }
+  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto strides = out.strides();
+  auto flags = out.flags();
+  flags.row_contiguous = false;
+  flags.col_contiguous = false;
+  flags.contiguous = false;
+  // TODO: Handle concurrent outputs:
+  // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
+  for (int i = 0; i < inputs.size(); i++) {
+    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
+    size_t data_offset = strides[axis] * sizes[i];
+    out_slice.copy_shared_buffer(
+        out, strides, flags, out_slice.size(), data_offset);
+    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
+  }
 }
 
 } // namespace mlx::core
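A usage sketch for the new make_cast_iterator helper, which this patch adds but does not yet call anywhere. The snippet below is hypothetical and not part of the patch: the function name cast_copy_example, the float to double pair, and the raw device pointers are assumptions. It only illustrates how the transform-iterator wrapper composes with a Thrust algorithm.

// Hypothetical caller, not part of the patch: a device-to-device copy with an
// element-wise cast, built on make_cast_iterator from cast_op.cuh.
#include <thrust/copy.h>
#include <thrust/system/cuda/execution_policy.h>

#include "mlx/backend/cuda/kernels/cast_op.cuh"

void cast_copy_example(
    const float* src, // device pointer holding n floats (assumed)
    double* dst, // device pointer with room for n doubles (assumed)
    size_t n,
    cudaStream_t stream) {
  // make_cast_iterator<double>(src) wraps `src` in a transform_iterator that
  // applies CastOp<float, double> on dereference, so copy_n casts as it copies.
  thrust::copy_n(
      thrust::cuda::par.on(stream),
      mlx::core::cu::make_cast_iterator<double>(src),
      n,
      dst);
}

When the source and destination value types already match, make_cast_iterator returns the iterator unchanged, so the same call degenerates to a plain copy with no wrapper overhead.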