[CUDA] Fix back-end bugs and enable corresponding tests (#2296)

* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
This commit is contained in:
Awni Hannun
2025-06-16 08:45:40 -07:00
committed by GitHub
parent 4fda5fbdf9
commit c552ff2451
16 changed files with 115 additions and 98 deletions

View File

@@ -59,9 +59,9 @@ void copy_general(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@@ -70,7 +70,7 @@ void copy_general(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
@@ -81,7 +81,7 @@ void copy_general(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),