From c353af5998a97862910e4307fa4635b2ce1df21f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 14 Jun 2025 06:16:09 -0700 Subject: [PATCH] fix cuda arg reduce --- mlx/backend/cuda/arg_reduce.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 7dbd91e46..c8a5a962a 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -1,5 +1,4 @@ // Copyright © 2025 Apple Inc. - #include "mlx/backend/common/utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/iterators/strided_iterator.cuh" @@ -113,7 +112,7 @@ __global__ void arg_reduce_general( for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { T vals[N_READS]; - auto tid = r * BLOCK_DIM + block.thread_index().z; + auto tid = r * BLOCK_DIM + block.thread_index().x; cub::LoadDirectBlocked( tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init); best = op.reduce_many(best, vals, tid * N_READS); @@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) { constexpr uint32_t N_READS = 4; MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block_dims{1, 1, BLOCK_DIM}; + dim3 block_dims{BLOCK_DIM, 1, 1}; auto kernel = &cu::arg_reduce_general< InType, cu::ArgMax,