From e3534c2db831d0e796ef3fb4c6201ddfba6edc99 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Wed, 9 Jul 2025 23:05:21 +0000
Subject: [PATCH] Contig uses uint as index and non-contig uses int

---
 mlx/backend/cuda/binary.cu     | 2 +-
 mlx/backend/cuda/binary_two.cu | 2 +-
 mlx/backend/cuda/ternary.cu    | 2 +-
 mlx/backend/cuda/unary.cu      | 3 ++-
 4 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index b5db056d3..fc5b8c496 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -247,7 +247,7 @@ void binary_op_gpu_inplace(
         }
       });
     } else {
-      dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
        using IdxT = std::conditional_t;
        // TODO: Choose optimized value based on type size.
        constexpr int N_READS = 4;
diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
index 93fb1da59..4b6e24581 100644
--- a/mlx/backend/cuda/binary_two.cu
+++ b/mlx/backend/cuda/binary_two.cu
@@ -269,7 +269,7 @@ void binary_two_op_gpu_inplace(
         }
       });
     } else {
-      dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
+      dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
        using IdxT = std::conditional_t;
        // TODO: Choose optimized value based on type size.
        constexpr int N_READS = 4;
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
index 2122ba497..eb69442c2 100644
--- a/mlx/backend/cuda/ternary.cu
+++ b/mlx/backend/cuda/ternary.cu
@@ -164,7 +164,7 @@ void ternary_op_gpu_inplace(
         }
       });
     } else {
-      dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
        using IdxT = std::conditional_t;
        // TODO: Choose optimized value based on type size.
        constexpr int N_READS = 4;
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 6a5fcce11..1fe1b557b 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -127,8 +127,8 @@ void unary_op_gpu_inplace(
       dispatch_bool(large, [&](auto large) {
         using InType = cuda_type_t;
         using OutType = cuda_type_t;
-        using IdxT = std::conditional_t;
         if (contig) {
+          using IdxT = std::conditional_t;
           // TODO: Choose optimized value based on type size.
           constexpr int N_READS = 4;
           auto kernel = cu::unary_v;
@@ -147,6 +147,7 @@ void unary_op_gpu_inplace(
               out.data(),
               out.data_size());
         } else {
+          using IdxT = std::conditional_t;
           auto [shape, strides] = collapse_contiguous_dims(in);
           auto kernel = cu::unary_g;
           auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);