CUDA backend: copy ops

Author: Cheng
Date: 2025-04-14 00:20:19 +00:00
parent 194212f65f
commit 2232084f58
5 changed files with 210 additions and 28 deletions

mlx/backend/cuda/CMakeLists.txt

@@ -7,7 +7,7 @@ target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
-         ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
+         ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu

mlx/backend/cuda/copy.cpp

@@ -1,26 +0,0 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/gpu/copy.h"

namespace mlx::core {

void copy_gpu_inplace(
    const array& in,
    array& out,
    const Shape& data_shape,
    const Strides& strides_in_pre,
    const Strides& strides_out_pre,
    int64_t inp_offset,
    int64_t out_offset,
    CopyType ctype,
    const Stream& s,
    const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
  throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
}

void fill_gpu(const array& val, array& out, const Stream& s) {
  throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
}

} // namespace mlx::core

126 mlx/backend/cuda/copy.cu Normal file

@@ -0,0 +1,126 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/cast_iterator.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/iterators/repeat_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <thrust/copy.h>
#include <thrust/device_ptr.h>

#include <cassert>

namespace mlx::core {

void copy_gpu_inplace(
    const array& in,
    array& out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out,
    int64_t inp_offset,
    int64_t out_offset,
    CopyType ctype,
    const Stream& s,
    const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
  if (out.size() == 0) {
    return;
  }
  // TODO: Figure out how to handle donated input.
  assert(in.data_shared_ptr() != nullptr);

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    // Dispatch over input and output dtypes.
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
        using InType = cuda_type_t<CTYPE_IN>;
        using OutType = cuda_type_t<CTYPE_OUT>;
        if constexpr (cu::CastOp<InType, OutType>::is_castable) {
          auto policy = cu::thrust_policy(stream);
          auto in_ptr = cu::make_cast_iterator<OutType>(
              thrust::device_pointer_cast(in.data<InType>() + inp_offset));
          auto out_ptr =
              thrust::device_pointer_cast(out.data<OutType>() + out_offset);
          if (ctype == CopyType::Scalar) {
            // Broadcast the single input element to the whole output.
            thrust::copy_n(
                policy, cu::repeat_iterator(in_ptr), out.data_size(), out_ptr);
          } else if (ctype == CopyType::Vector) {
            // Contiguous input to contiguous output.
            thrust::copy_n(policy, in_ptr, out.data_size(), out_ptr);
          } else {
            bool dynamic = dynamic_i_offset || dynamic_o_offset;
            if (dynamic) {
              throw std::runtime_error(
                  "Dynamic copy not implemented for CUDA backend.");
            }
            // Merge adjacent dimensions that are contiguous in both arrays to
            // simplify the strided iteration.
            auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
                shape,
                std::vector{strides_in, strides_out},
                /* size_cap = */ INT32_MAX);
            if (ctype == CopyType::General) {
              // Strided input, contiguous output.
              thrust::copy_n(
                  policy,
                  cu::make_general_iterator<int64_t>(
                      in_ptr, shape_collapsed, strides_vec[0]),
                  out.data_size(),
                  out_ptr);
            } else {
              // CopyType::GeneralGeneral: strided input and strided output.
              thrust::copy_n(
                  policy,
                  cu::make_general_iterator<int64_t>(
                      in_ptr, shape_collapsed, strides_vec[0]),
                  out.data_size(),
                  cu::make_general_iterator<int64_t>(
                      out_ptr, shape_collapsed, strides_vec[1]));
            }
          }
        } else {
          throw std::runtime_error(fmt::format(
              "Can not copy data from dtype {} to {}.",
              dtype_to_string(in.dtype()),
              dtype_to_string(out.dtype())));
        }
      });
    });
  });
}

void fill_gpu(const array& val, array& out, const Stream& s) {
  if (out.size() == 0) {
    return;
  }
  out.set_data(allocator::malloc(out.nbytes()));
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(val);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(val.dtype(), CTYPE_IN, {
      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
        using InType = cuda_type_t<CTYPE_IN>;
        using OutType = cuda_type_t<CTYPE_OUT>;
        if constexpr (cu::CastOp<InType, OutType>::is_castable) {
          // Repeat the single value across the whole output, casting as needed.
          thrust::copy_n(
              cu::thrust_policy(stream),
              cu::make_cast_iterator<OutType>(cu::repeat_iterator(
                  thrust::device_pointer_cast(val.data<InType>()))),
              out.data_size(),
              thrust::device_pointer_cast(out.data<OutType>()));
        } else {
          throw std::runtime_error(fmt::format(
              "Can not fill data of dtype {} with {}",
              dtype_to_string(out.dtype()),
              dtype_to_string(val.dtype())));
        }
      });
    });
  });
}

} // namespace mlx::core
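
For orientation: in the general (strided) cases above, each linear output index is mapped through the collapsed shape and the per-array strides. A rough host-side sketch of that addressing, where elem_to_loc and strided_copy are illustrative stand-ins rather than the actual cu::general_iterator machinery:

#include <cstdint>
#include <vector>

// Map a linear element index to a memory offset for a given shape/strides pair.
inline int64_t elem_to_loc(
    int64_t idx,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (idx % shape[i]) * strides[i];
    idx /= shape[i];
  }
  return loc;
}

// CPU analogue of the CopyType::GeneralGeneral case: both input and output
// are addressed through their own strides, with a cast on every element.
template <typename InT, typename OutT>
void strided_copy(
    const InT* in,
    OutT* out,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides_in,
    const std::vector<int64_t>& strides_out,
    int64_t size) {
  for (int64_t i = 0; i < size; ++i) {
    out[elem_to_loc(i, shape, strides_out)] =
        static_cast<OutT>(in[elem_to_loc(i, shape, strides_in)]);
  }
}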

mlx/backend/cuda/iterators/cast_iterator.cuh Normal file

@@ -0,0 +1,56 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <cuComplex.h>
#include <thrust/iterator/transform_iterator.h>

namespace mlx::core::cu {

template <typename SrcT, typename DstT, typename = void>
struct CastOp {
  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;

  __device__ DstT operator()(SrcT x) {
    return static_cast<DstT>(x);
  }
};

// Casting from complex to real takes the real part.
template <typename DstT>
struct CastOp<
    cuComplex,
    DstT,
    cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
  static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;

  __device__ DstT operator()(cuComplex x) {
    static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
    return static_cast<DstT>(cuCrealf(x));
  }
};

// Casting from real to complex sets the imaginary part to 0.
template <typename SrcT>
struct CastOp<
    SrcT,
    cuComplex,
    cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;

  __device__ cuComplex operator()(SrcT x) {
    static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
    return cuComplex{static_cast<float>(x), 0};
  }
};
// Return an iterator that casts the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {
  using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
  if constexpr (std::is_same_v<SrcT, DstT>) {
    return it;
  } else {
    return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
  }
}
} // namespace mlx::core::cu
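
A minimal standalone sketch of the pattern this header builds on, using plain Thrust rather than MLX code (CastFloatToDouble here is a made-up stand-in for CastOp): wrapping a device iterator in thrust::make_transform_iterator makes every read cast on the fly, so an ordinary thrust::copy_n doubles as a dtype conversion.

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/iterator/transform_iterator.h>

struct CastFloatToDouble {
  __host__ __device__ double operator()(float x) const {
    return static_cast<double>(x);
  }
};

int main() {
  thrust::device_vector<float> src(8, 1.5f);
  thrust::device_vector<double> dst(8);

  // Every read through cast_it yields the float value converted to double,
  // so the copy below performs the conversion without a separate kernel.
  auto cast_it = thrust::make_transform_iterator(src.begin(), CastFloatToDouble{});
  thrust::copy_n(cast_it, src.size(), dst.begin());
  return 0;
}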

mlx/backend/cuda/slicing.cpp

@@ -1,7 +1,11 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/slicing.h"
+#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"

+#include <numeric>

namespace mlx::core {

void concatenate_gpu(
@@ -9,7 +13,29 @@ void concatenate_gpu(
    array& out,
    int axis,
    const Stream& s) {
-  throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
+  std::vector<int> sizes;
+  sizes.push_back(0);
+  for (auto& p : inputs) {
+    sizes.push_back(p.shape(axis));
+  }
+  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto strides = out.strides();
+  auto flags = out.flags();
+  flags.row_contiguous = false;
+  flags.col_contiguous = false;
+  flags.contiguous = false;
+
+  // TODO: Handle concurrent outputs:
+  // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
+  for (int i = 0; i < inputs.size(); i++) {
+    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
+    size_t data_offset = strides[axis] * sizes[i];
+    out_slice.copy_shared_buffer(
+        out, strides, flags, out_slice.size(), data_offset);
+    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
+  }
}
} // namespace mlx::core
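
To make the offset bookkeeping concrete, a small self-contained example (plain C++, not MLX code) of the sizes/partial_sum arithmetic when concatenating a {2, 4} and a {3, 4} array along axis 0 into a {5, 4} output with row-major strides {4, 1}:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // shape(axis) of each input, prefixed with 0, turned into running sums.
  std::vector<int> sizes{0, 2, 3};
  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); // {0, 2, 5}

  std::vector<int64_t> strides{4, 1}; // row-major strides of the {5, 4} output
  int axis = 0;

  // Element offset of the output view that receives each input:
  // input 0 starts at 4 * 0 = 0, input 1 at 4 * 2 = 8.
  for (int i = 0; i + 1 < static_cast<int>(sizes.size()); ++i) {
    std::cout << "input " << i << " offset = " << strides[axis] * sizes[i] << "\n";
  }
  return 0;
}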