rebase + nit (#2260)

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Cheng 2025-06-11 02:51:51 +09:00 committed by GitHub
parent 62fecf3e13
commit 99c33d011d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 604 additions and 28 deletions

View File

@ -7,7 +7,11 @@ target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
@ -28,6 +32,15 @@ target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
# Explicitly pass this flag to suppress the warning, it is safe to set it to
# true but the warning wouldn't be suppressed.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
target_compile_options(
mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES

View File

@ -1,26 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
void copy_gpu_inplace(
const array& in,
array& out,
const Shape& data_shape,
const Strides& strides_in_pre,
const Strides& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend.");
}
void fill_gpu(const array& val, array& out, const Stream& s) {
throw std::runtime_error("fill_gpu not implemented in CUDA backend.");
}
} // namespace mlx::core

89
mlx/backend/cuda/copy.cu Normal file
View File

@ -0,0 +1,89 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/copy/copy.cuh"
namespace mlx::core {
void copy_gpu_inplace(
const array& in_,
array& out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out,
int64_t offset_in,
int64_t offset_out,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_offset_in,
const std::optional<array>& dynamic_offset_out) {
if (out.size() == 0) {
return;
}
const array& in = in_.data_shared_ptr() ? in_ : out;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
return;
}
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
shape, std::vector{strides_in, strides_out}, INT32_MAX);
if (ctype == CopyType::General) {
copy_general_input(
encoder,
ctype,
in,
out,
offset_in,
offset_out,
shape_collapsed,
strides_vec[0]);
} else {
if (dynamic_offset_in || dynamic_offset_out) {
copy_general_dynamic(
encoder,
ctype,
in,
out,
offset_in,
offset_out,
shape_collapsed,
strides_vec[0],
strides_vec[1],
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
} else {
copy_general(
encoder,
ctype,
in,
out,
offset_in,
offset_out,
shape_collapsed,
strides_vec[0],
strides_vec[1]);
}
}
return;
}
}
void fill_gpu(const array& in, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}
} // namespace mlx::core

View File

@ -0,0 +1,71 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
namespace mlx::core {
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
__VA_ARGS__; \
} else { \
throw std::runtime_error(fmt::format( \
"Can not copy data from dtype {} to {}.", \
dtype_to_string(out.dtype()), \
dtype_to_string(in.dtype()))); \
} \
}); \
})
void copy_contiguous(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out);
void copy_general(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out);
void copy_general_dynamic(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out,
const array& dynamic_offset_in,
const array& dynamic_offset_out);
void copy_general_input(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in);
} // namespace mlx::core

View File

@ -0,0 +1,56 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT>
__global__ void copy_s(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = CastOp<In, Out>{}(in[0]);
}
}
template <typename In, typename Out, typename IdxT>
__global__ void copy_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = CastOp<In, Out>{}(in[index]);
}
}
} // namespace cu
void copy_contiguous(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t in_offset,
int64_t out_offset) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::copy_s<InType, OutType, IdxT>;
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
out.data_size());
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,95 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_nd(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index, shape.data(), strides_in.data(), strides_out.data());
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
}
}
template <typename In, typename Out, typename IdxT>
__global__ void copy_gg(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
const __grid_constant__ Strides strides_out,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_4d(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
}
}
} // namespace cu
void copy_general(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,105 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_dynamic_nd(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,
const int64_t* offset_in,
const int64_t* offset_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index, shape.data(), strides_in.data(), strides_out.data());
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
}
}
template <typename In, typename Out, typename IdxT>
__global__ void copy_gg_dynamic(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
const __grid_constant__ Strides strides_out,
int ndim,
const int64_t* offset_in,
const int64_t* offset_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_4d(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
}
}
} // namespace cu
void copy_general_dynamic(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out,
const array& dynamic_offset_in,
const array& dynamic_offset_out) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
});
} else { // ndim >= 4
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
}
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,88 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/copy/copy.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_g_nd(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
out[index] = CastOp<In, Out>{}(in[idx_in]);
}
}
template <typename In, typename Out, typename IdxT>
__global__ void copy_g(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]);
}
}
} // namespace cu
void copy_general_input(
cu::CommandEncoder& encoder,
CopyType ctype,
const array& in,
array& out,
int64_t offset_in,
int64_t offset_out,
const Shape& shape,
const Strides& strides_in) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in));
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
const_param(shape),
const_param(strides_in),
ndim);
}
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuComplex.h>
#include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu {
// An op that does static_cast, with custom conversions for some types.
template <typename SrcT, typename DstT, typename = void>
struct CastOp {
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;
__device__ DstT operator()(SrcT x) {
return static_cast<DstT>(x);
}
};
// Converting a complex number to real number discards the imaginary part.
template <typename DstT>
struct CastOp<
cuComplex,
DstT,
cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
__device__ DstT operator()(cuComplex x) {
static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
return static_cast<DstT>(cuCrealf(x));
}
};
// Allow converting a real number to complex number.
template <typename SrcT>
struct CastOp<
SrcT,
cuComplex,
cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
__device__ cuComplex operator()(SrcT x) {
static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
return cuComplex{static_cast<float>(x), 0};
}
};
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) {
return it;
} else {
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
}
}
} // namespace mlx::core::cu

View File

@ -1,7 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include <numeric>
namespace mlx::core {
void concatenate_gpu(
@ -9,7 +13,29 @@ void concatenate_gpu(
array& out,
int axis,
const Stream& s) {
throw std::runtime_error("concatenate_gpu not implemented in CUDA backend.");
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
// TODO: Handle concurrent outputs:
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
}
}
} // namespace mlx::core