diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index b5db056d3..fc5b8c496 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -247,7 +247,7 @@ void binary_op_gpu_inplace( } }); } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; // TODO: Choose optimized value based on type size. constexpr int N_READS = 4; diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 93fb1da59..4b6e24581 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -269,7 +269,7 @@ void binary_two_op_gpu_inplace( } }); } else { - dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { + dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; // TODO: Choose optimized value based on type size. constexpr int N_READS = 4; diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 2122ba497..eb69442c2 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -164,7 +164,7 @@ void ternary_op_gpu_inplace( } }); } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; // TODO: Choose optimized value based on type size. constexpr int N_READS = 4; diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 6a5fcce11..1fe1b557b 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -127,8 +127,8 @@ void unary_op_gpu_inplace( dispatch_bool(large, [&](auto large) { using InType = cuda_type_t; using OutType = cuda_type_t; - using IdxT = std::conditional_t; if (contig) { + using IdxT = std::conditional_t; // TODO: Choose optimized value based on type size. constexpr int N_READS = 4; auto kernel = cu::unary_v; @@ -147,6 +147,7 @@ void unary_op_gpu_inplace( out.data(), out.data_size()); } else { + using IdxT = std::conditional_t; auto [shape, strides] = collapse_contiguous_dims(in); auto kernel = cu::unary_g; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);