From 2232084f589d65a6971c172b2d391d39b64f9e90 Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 14 Apr 2025 00:20:19 +0000 Subject: [PATCH] CUDA backend: copy ops --- mlx/backend/cuda/CMakeLists.txt | 2 +- mlx/backend/cuda/copy.cpp | 26 ---- mlx/backend/cuda/copy.cu | 126 +++++++++++++++++++ mlx/backend/cuda/iterators/cast_iterator.cuh | 56 +++++++++ mlx/backend/cuda/slicing.cpp | 28 ++++- 5 files changed, 210 insertions(+), 28 deletions(-) delete mode 100644 mlx/backend/cuda/copy.cpp create mode 100644 mlx/backend/cuda/copy.cu create mode 100644 mlx/backend/cuda/iterators/cast_iterator.cuh diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 1a3e95059..a077baed0 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -7,7 +7,7 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu - ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu diff --git a/mlx/backend/cuda/copy.cpp b/mlx/backend/cuda/copy.cpp deleted file mode 100644 index d0413d989..000000000 --- a/mlx/backend/cuda/copy.cpp +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/gpu/copy.h" - -namespace mlx::core { - -void copy_gpu_inplace( - const array& in, - array& out, - const Shape& data_shape, - const Strides& strides_in_pre, - const Strides& strides_out_pre, - int64_t inp_offset, - int64_t out_offset, - CopyType ctype, - const Stream& s, - const std::optional& dynamic_i_offset /* = std::nullopt */, - const std::optional& dynamic_o_offset /* = std::nullopt */) { - throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend."); -} - -void fill_gpu(const array& val, array& out, const Stream& s) { - throw std::runtime_error("fill_gpu not implemented in CUDA backend."); -} - -} // namespace mlx::core diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu new file mode 100644 index 000000000..93b22f23e --- /dev/null +++ b/mlx/backend/cuda/copy.cu @@ -0,0 +1,126 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/cast_iterator.cuh" +#include "mlx/backend/cuda/iterators/general_iterator.cuh" +#include "mlx/backend/cuda/iterators/repeat_iterator.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +namespace mlx::core { + +void copy_gpu_inplace( + const array& in, + array& out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + int64_t inp_offset, + int64_t out_offset, + CopyType ctype, + const Stream& s, + const std::optional& dynamic_i_offset /* = std::nullopt */, + const std::optional& dynamic_o_offset /* = std::nullopt */) { + if (out.size() == 0) { + return; + } + // TODO: Figure out how to handle donated input. + assert(in.data_shared_ptr() != nullptr); + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + using InType = cuda_type_t; + using OutType = cuda_type_t; + if constexpr (cu::CastOp::is_castable) { + auto policy = cu::thrust_policy(stream); + auto in_ptr = cu::make_cast_iterator( + thrust::device_pointer_cast(in.data() + inp_offset)); + auto out_ptr = + thrust::device_pointer_cast(out.data() + out_offset); + if (ctype == CopyType::Scalar) { + thrust::copy_n( + policy, cu::repeat_iterator(in_ptr), out.data_size(), out_ptr); + } else if (ctype == CopyType::Vector) { + thrust::copy_n(policy, in_ptr, out.data_size(), out_ptr); + } else { + bool dynamic = dynamic_i_offset || dynamic_o_offset; + if (dynamic) { + throw std::runtime_error( + "Dynamic copy not implemented for CUDA backend."); + } + auto [shape_collapsed, strides_vec] = collapse_contiguous_dims( + shape, + std::vector{strides_in, strides_out}, + /* size_cap = */ INT32_MAX); + if (ctype == CopyType::General) { + thrust::copy_n( + policy, + cu::make_general_iterator( + in_ptr, shape_collapsed, strides_vec[0]), + out.data_size(), + out_ptr); + } else { + thrust::copy_n( + policy, + cu::make_general_iterator( + in_ptr, shape_collapsed, strides_vec[0]), + out.data_size(), + cu::make_general_iterator( + out_ptr, shape_collapsed, strides_vec[1])); + } + } + } else { + throw std::runtime_error(fmt::format( + "Can not copy data from dtype {} to {}.", + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); + }); +} + +void fill_gpu(const array& val, array& out, const Stream& s) { + if (out.size() == 0) { + return; + } + out.set_data(allocator::malloc(out.nbytes())); + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(val); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(val.dtype(), CTYPE_IN, { + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + using InType = cuda_type_t; + using OutType = cuda_type_t; + if constexpr (cu::CastOp::is_castable) { + thrust::copy_n( + cu::thrust_policy(stream), + cu::make_cast_iterator(cu::repeat_iterator( + thrust::device_pointer_cast(val.data()))), + out.data_size(), + thrust::device_pointer_cast(out.data())); + } else { + throw std::runtime_error(fmt::format( + "Can not fill data of dtype {} with {}", + dtype_to_string(out.dtype()), + dtype_to_string(val.dtype()))); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/iterators/cast_iterator.cuh b/mlx/backend/cuda/iterators/cast_iterator.cuh new file mode 100644 index 000000000..85cefb53b --- /dev/null +++ b/mlx/backend/cuda/iterators/cast_iterator.cuh @@ -0,0 +1,56 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::cu { + +template +struct CastOp { + static constexpr bool is_castable = cuda::std::is_convertible_v; + + __device__ DstT operator()(SrcT x) { + return static_cast(x); + } +}; + +template +struct CastOp< + cuComplex, + DstT, + cuda::std::enable_if_t>> { + static constexpr bool is_castable = cuda::std::is_convertible_v; + + __device__ DstT operator()(cuComplex x) { + static_assert(!cuda::std::is_same_v); + return static_cast(cuCrealf(x)); + } +}; + +template +struct CastOp< + SrcT, + cuComplex, + cuda::std::enable_if_t>> { + static constexpr bool is_castable = cuda::std::is_convertible_v; + + __device__ cuComplex operator()(SrcT x) { + static_assert(!cuda::std::is_same_v); + return cuComplex{static_cast(x), 0}; + } +}; + +// Return an iterator that custom_cast the value to DstT. +template +__host__ __device__ auto make_cast_iterator(Iterator it) { + using SrcT = typename cuda::std::iterator_traits::value_type; + if constexpr (std::is_same_v) { + return it; + } else { + return thrust::make_transform_iterator(it, CastOp{}); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp index bfa742c74..af67fbbdd 100644 --- a/mlx/backend/cuda/slicing.cpp +++ b/mlx/backend/cuda/slicing.cpp @@ -1,7 +1,11 @@ // Copyright © 2025 Apple Inc. +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" +#include + namespace mlx::core { void concatenate_gpu( @@ -9,7 +13,29 @@ void concatenate_gpu( array& out, int axis, const Stream& s) { - throw std::runtime_error("concatenate_gpu not implemented in CUDA backend."); + std::vector sizes; + sizes.push_back(0); + for (auto& p : inputs) { + sizes.push_back(p.shape(axis)); + } + std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); + + out.set_data(allocator::malloc(out.nbytes())); + + auto strides = out.strides(); + auto flags = out.flags(); + flags.row_contiguous = false; + flags.col_contiguous = false; + flags.contiguous = false; + // TODO: Handle concurrent outputs: + // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816 + for (int i = 0; i < inputs.size(); i++) { + array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); + size_t data_offset = strides[axis] * sizes[i]; + out_slice.copy_shared_buffer( + out, strides, flags, out_slice.size(), data_offset); + copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s); + } } } // namespace mlx::core