From b13c7ef8f8083a913429c7be3df485b6ddf0a482 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 15 Jun 2025 13:09:06 -0700 Subject: [PATCH 1/2] Fix some cuda back-end bugs and enable corresponding tests --- mlx/backend/cuda/binary.cu | 12 +++++++++++- mlx/backend/cuda/copy/copy.cuh | 9 +-------- mlx/backend/cuda/unary.cu | 2 +- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 47efc44d2..45ade0fda 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -264,7 +264,6 @@ BINARY_GPU(Add) BINARY_GPU(ArcTan2) BINARY_GPU(Divide) BINARY_GPU(Remainder) -BINARY_GPU(Equal) BINARY_GPU(Greater) BINARY_GPU(GreaterEqual) BINARY_GPU(Less) @@ -279,6 +278,17 @@ BINARY_GPU(NotEqual) BINARY_GPU(Power) BINARY_GPU(Subtract) +void Equal::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Equal::eval_gpu"); + auto& s = out.primitive().stream(); + auto op = get_primitive_string(this); + if (equal_nan_) { + binary_op_gpu(inputs, out, op, s); + } else { + binary_op_gpu(inputs, out, op, s); + } +} + void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); auto& s = out.primitive().stream(); diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh index 0c1eff774..ee5120274 100644 --- a/mlx/backend/cuda/copy/copy.cuh +++ b/mlx/backend/cuda/copy/copy.cuh @@ -15,14 +15,7 @@ namespace mlx::core { MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ using InType = cuda_type_t; \ using OutType = cuda_type_t; \ - if constexpr (cu::CastOp::is_castable) { \ - __VA_ARGS__; \ - } else { \ - throw std::runtime_error(fmt::format( \ - "Can not copy data from dtype {} to {}.", \ - dtype_to_string(out.dtype()), \ - dtype_to_string(in.dtype()))); \ - } \ + __VA_ARGS__; \ }); \ }) diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index f9d373455..f02898705 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -91,7 +91,7 @@ void unary_op_gpu_inplace( } else { auto [shape, strides] = collapse_contiguous_dims(in); auto [in_begin, in_end] = cu::make_general_iterators( - in_ptr, in.data_size(), shape, strides); + in_ptr, in.size(), shape, strides); thrust::transform(policy, in_begin, in_end, out_ptr, Op()); } } else { From 85869fda0c99df6f1de31cc8a4baa25afa4620e3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 15 Jun 2025 20:44:32 -0700 Subject: [PATCH 2/2] more fixes --- mlx/backend/cuda/binary.cu | 6 +++--- mlx/backend/cuda/copy.cu | 3 +-- mlx/backend/cuda/copy/copy_contiguous.cu | 3 ++- mlx/backend/cuda/copy/copy_general.cu | 8 ++++---- mlx/backend/cuda/copy/copy_general_dynamic.cu | 8 ++++---- mlx/backend/cuda/copy/copy_general_input.cu | 8 ++++---- mlx/backend/cuda/device/cast_op.cuh | 13 +++++++++++++ mlx/backend/cuda/kernel_utils.cuh | 17 ++++++++++++++--- mlx/backend/cuda/ternary.cu | 4 ++-- python/tests/cuda_skip.py | 3 --- 10 files changed, 47 insertions(+), 26 deletions(-) diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 45ade0fda..0d2389de1 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -165,7 +165,7 @@ void binary_op_gpu_inplace( a.data(), b.data(), out.data(), - out.data_size(), + out.size(), const_param(shape), const_param(a_strides), const_param(b_strides)); @@ -178,7 +178,7 @@ void binary_op_gpu_inplace( a.data(), b.data(), out.data(), - out.data_size(), + out.size(), const_param(shape), const_param(a_strides), const_param(b_strides), @@ -197,7 +197,7 @@ void binary_op_gpu_inplace( kernel = cu::binary_vv; } auto [num_blocks, block_dims] = - get_launch_args(kernel, out, LARGE); + get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu index 8649e1bf9..817860d0a 100644 --- a/mlx/backend/cuda/copy.cu +++ b/mlx/backend/cuda/copy.cu @@ -6,7 +6,7 @@ namespace mlx::core { void copy_gpu_inplace( - const array& in_, + const array& in, array& out, const Shape& shape, const Strides& strides_in, @@ -20,7 +20,6 @@ void copy_gpu_inplace( if (out.size() == 0) { return; } - const array& in = in_.data_shared_ptr() ? in_ : out; auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index fa79f0604..854fd93b4 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -43,7 +43,8 @@ void copy_contiguous( if (ctype == CopyType::Vector) { kernel = cu::copy_v; } - auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE); + auto [num_blocks, block_dims] = + get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel<<>>( in.data() + in_offset, out.data() + out_offset, diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 3c5b3bbb3..9f50c8a31 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -59,9 +59,9 @@ void copy_general( MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { @@ -70,7 +70,7 @@ void copy_general( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out)); @@ -81,7 +81,7 @@ void copy_general( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out), diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index b9774662a..2e1cf4fba 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -65,9 +65,9 @@ void copy_general_dynamic( MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { @@ -76,7 +76,7 @@ void copy_general_dynamic( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out), @@ -89,7 +89,7 @@ void copy_general_dynamic( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out), diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index 4f2784927..a3bb37e53 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -54,9 +54,9 @@ void copy_general_input( MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { @@ -65,7 +65,7 @@ void copy_general_input( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in)); }); @@ -75,7 +75,7 @@ void copy_general_input( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), ndim); diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh index 30b44d46f..115395db7 100644 --- a/mlx/backend/cuda/device/cast_op.cuh +++ b/mlx/backend/cuda/device/cast_op.cuh @@ -45,6 +45,19 @@ struct CastOp< } }; +template +struct CastOp< + SrcT, + DstT, + cuda::std::enable_if_t>> { + static constexpr bool is_castable = true; + + __device__ SrcT operator()(SrcT x) { + return x; + } +}; + + // Return an iterator that cast the value to DstT using CastOp. template __host__ __device__ auto make_cast_iterator(Iterator it) { diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index 84392a1ec..05febaa5c 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -136,17 +136,19 @@ inline uint max_occupancy_block_dim(T kernel) { template inline std::tuple get_launch_args( T kernel, - const array& arr, + size_t size, + const Shape& shape, + const Strides& strides, bool large, int work_per_thread = 1) { - size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); + size_t nthreads = cuda::ceil_div(size, work_per_thread); uint block_dim = max_occupancy_block_dim(kernel); if (block_dim > nthreads) { block_dim = nthreads; } dim3 num_blocks; if (large) { - num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread); + num_blocks = get_2d_grid_dims(shape, strides, work_per_thread); num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim); } else { num_blocks.x = cuda::ceil_div(nthreads, block_dim); @@ -154,4 +156,13 @@ inline std::tuple get_launch_args( return std::make_tuple(num_blocks, block_dim); } +template +inline std::tuple get_launch_args( + T kernel, + const array& arr, + bool large, + int work_per_thread = 1) { + return get_launch_args(kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index bb79d4249..41441ff40 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -116,7 +116,7 @@ void ternary_op_gpu_inplace( b.data(), c.data(), out.data(), - out.data_size(), + out.size(), const_param(shape), const_param(a_strides), const_param(b_strides), @@ -142,7 +142,7 @@ void ternary_op_gpu_inplace( MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { using IdxT = std::conditional_t; auto kernel = cu::ternary_v; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE); + auto [num_blocks, block_dims] = get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel<<>>( a.data(), b.data(), diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index cda396dcb..ae3a32afc 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -100,7 +100,6 @@ cuda_skip = { "TestOps.test_hadamard", "TestOps.test_hadamard_grad_vmap", "TestOps.test_irregular_binary_ops", - "TestOps.test_isfinite", "TestOps.test_kron", "TestOps.test_log", "TestOps.test_log10", @@ -115,7 +114,6 @@ cuda_skip = { "TestOps.test_sort", "TestOps.test_tensordot", "TestOps.test_tile", - "TestOps.test_view", "TestQuantized.test_gather_matmul_grad", "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", @@ -136,7 +134,6 @@ cuda_skip = { "TestReduce.test_expand_sums", "TestReduce.test_many_reduction_axes", "TestUpsample.test_torch_upsample", - "TestVmap.test_unary", "TestVmap.test_vmap_conv", "TestVmap.test_vmap_inverse", "TestVmap.test_vmap_svd",