mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
CUDA backend: copy ops
This commit is contained in:
@@ -7,7 +7,7 @@ target_sources(
|
|||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/gpu/copy.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
void copy_gpu_inplace(
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
const Shape& data_shape,
|
|
||||||
const Strides& strides_in_pre,
|
|
||||||
const Strides& strides_out_pre,
|
|
||||||
int64_t inp_offset,
|
|
||||||
int64_t out_offset,
|
|
||||||
CopyType ctype,
|
|
||||||
const Stream& s,
|
|
||||||
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
|
|
||||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
|
||||||
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
|
|
||||||
}
|
|
||||||
|
|
||||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
|
||||||
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
126
mlx/backend/cuda/copy.cu
Normal file
126
mlx/backend/cuda/copy.cu
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/iterators/cast_iterator.cuh"
|
||||||
|
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
||||||
|
#include "mlx/backend/cuda/iterators/repeat_iterator.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <thrust/copy.h>
|
||||||
|
#include <thrust/device_ptr.h>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void copy_gpu_inplace(
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
const Shape& shape,
|
||||||
|
const Strides& strides_in,
|
||||||
|
const Strides& strides_out,
|
||||||
|
int64_t inp_offset,
|
||||||
|
int64_t out_offset,
|
||||||
|
CopyType ctype,
|
||||||
|
const Stream& s,
|
||||||
|
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
|
||||||
|
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// TODO: Figure out how to handle donated input.
|
||||||
|
assert(in.data_shared_ptr() != nullptr);
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
|
||||||
|
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||||
|
using InType = cuda_type_t<CTYPE_IN>;
|
||||||
|
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||||
|
if constexpr (cu::CastOp<InType, OutType>::is_castable) {
|
||||||
|
auto policy = cu::thrust_policy(stream);
|
||||||
|
auto in_ptr = cu::make_cast_iterator<OutType>(
|
||||||
|
thrust::device_pointer_cast(in.data<InType>() + inp_offset));
|
||||||
|
auto out_ptr =
|
||||||
|
thrust::device_pointer_cast(out.data<OutType>() + out_offset);
|
||||||
|
if (ctype == CopyType::Scalar) {
|
||||||
|
thrust::copy_n(
|
||||||
|
policy, cu::repeat_iterator(in_ptr), out.data_size(), out_ptr);
|
||||||
|
} else if (ctype == CopyType::Vector) {
|
||||||
|
thrust::copy_n(policy, in_ptr, out.data_size(), out_ptr);
|
||||||
|
} else {
|
||||||
|
bool dynamic = dynamic_i_offset || dynamic_o_offset;
|
||||||
|
if (dynamic) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Dynamic copy not implemented for CUDA backend.");
|
||||||
|
}
|
||||||
|
auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
|
||||||
|
shape,
|
||||||
|
std::vector{strides_in, strides_out},
|
||||||
|
/* size_cap = */ INT32_MAX);
|
||||||
|
if (ctype == CopyType::General) {
|
||||||
|
thrust::copy_n(
|
||||||
|
policy,
|
||||||
|
cu::make_general_iterator<int64_t>(
|
||||||
|
in_ptr, shape_collapsed, strides_vec[0]),
|
||||||
|
out.data_size(),
|
||||||
|
out_ptr);
|
||||||
|
} else {
|
||||||
|
thrust::copy_n(
|
||||||
|
policy,
|
||||||
|
cu::make_general_iterator<int64_t>(
|
||||||
|
in_ptr, shape_collapsed, strides_vec[0]),
|
||||||
|
out.data_size(),
|
||||||
|
cu::make_general_iterator<int64_t>(
|
||||||
|
out_ptr, shape_collapsed, strides_vec[1]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Can not copy data from dtype {} to {}.",
|
||||||
|
dtype_to_string(in.dtype()),
|
||||||
|
dtype_to_string(out.dtype())));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(val);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(val.dtype(), CTYPE_IN, {
|
||||||
|
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
|
||||||
|
using InType = cuda_type_t<CTYPE_IN>;
|
||||||
|
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||||
|
if constexpr (cu::CastOp<InType, OutType>::is_castable) {
|
||||||
|
thrust::copy_n(
|
||||||
|
cu::thrust_policy(stream),
|
||||||
|
cu::make_cast_iterator<OutType>(cu::repeat_iterator(
|
||||||
|
thrust::device_pointer_cast(val.data<InType>()))),
|
||||||
|
out.data_size(),
|
||||||
|
thrust::device_pointer_cast(out.data<OutType>()));
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Can not fill data of dtype {} with {}",
|
||||||
|
dtype_to_string(out.dtype()),
|
||||||
|
dtype_to_string(val.dtype())));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
56
mlx/backend/cuda/iterators/cast_iterator.cuh
Normal file
56
mlx/backend/cuda/iterators/cast_iterator.cuh
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuComplex.h>
|
||||||
|
#include <thrust/iterator/transform_iterator.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT, typename = void>
|
||||||
|
struct CastOp {
|
||||||
|
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;
|
||||||
|
|
||||||
|
__device__ DstT operator()(SrcT x) {
|
||||||
|
return static_cast<DstT>(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename DstT>
|
||||||
|
struct CastOp<
|
||||||
|
cuComplex,
|
||||||
|
DstT,
|
||||||
|
cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
|
||||||
|
static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
|
||||||
|
|
||||||
|
__device__ DstT operator()(cuComplex x) {
|
||||||
|
static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
|
||||||
|
return static_cast<DstT>(cuCrealf(x));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SrcT>
|
||||||
|
struct CastOp<
|
||||||
|
SrcT,
|
||||||
|
cuComplex,
|
||||||
|
cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
|
||||||
|
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
|
||||||
|
|
||||||
|
__device__ cuComplex operator()(SrcT x) {
|
||||||
|
static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
|
||||||
|
return cuComplex{static_cast<float>(x), 0};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return an iterator that custom_cast the value to DstT.
|
||||||
|
template <typename DstT, typename Iterator>
|
||||||
|
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||||
|
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
|
||||||
|
if constexpr (std::is_same_v<SrcT, DstT>) {
|
||||||
|
return it;
|
||||||
|
} else {
|
||||||
|
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
||||||
@@ -1,7 +1,11 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/slicing.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/gpu/slicing.h"
|
#include "mlx/backend/gpu/slicing.h"
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void concatenate_gpu(
|
void concatenate_gpu(
|
||||||
@@ -9,7 +13,29 @@ void concatenate_gpu(
|
|||||||
array& out,
|
array& out,
|
||||||
int axis,
|
int axis,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
|
std::vector<int> sizes;
|
||||||
|
sizes.push_back(0);
|
||||||
|
for (auto& p : inputs) {
|
||||||
|
sizes.push_back(p.shape(axis));
|
||||||
|
}
|
||||||
|
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
auto strides = out.strides();
|
||||||
|
auto flags = out.flags();
|
||||||
|
flags.row_contiguous = false;
|
||||||
|
flags.col_contiguous = false;
|
||||||
|
flags.contiguous = false;
|
||||||
|
// TODO: Handle concurrent outputs:
|
||||||
|
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
|
||||||
|
for (int i = 0; i < inputs.size(); i++) {
|
||||||
|
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||||
|
size_t data_offset = strides[axis] * sizes[i];
|
||||||
|
out_slice.copy_shared_buffer(
|
||||||
|
out, strides, flags, out_slice.size(), data_offset);
|
||||||
|
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
Reference in New Issue
Block a user