diff --git a/mlx/backend/metal/kernels/reduction/reduce_row.h b/mlx/backend/metal/kernels/reduction/reduce_row.h index c8973429f..936d75bb5 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_row.h +++ b/mlx/backend/metal/kernels/reduction/reduce_row.h @@ -224,7 +224,7 @@ template < if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { // Simple loop over non_row_reductions and reduce the row in the thread. - IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) {