From aa7841814cc060e9e75328ef97ba6fe2ed11e5e5 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 29 Jun 2025 04:11:09 -0700 Subject: [PATCH] Change name of argument --- mlx/backend/cuda/arg_reduce.cu | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 727cb027d..90f8561c1 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -156,22 +156,15 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { using T = cuda_type_t; constexpr uint32_t N_READS = 4; dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim_constant) { + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block_dims{block_dim_constant(), 1, 1}; - auto kernel = cu::arg_reduce_general< - T, - cu::ArgMax, - block_dim_constant(), - N_READS>; + auto kernel = + cu::arg_reduce_general, block_dim(), N_READS>; if (reduce_type_ == ArgReduce::ArgMin) { - kernel = cu::arg_reduce_general< - T, - cu::ArgMin, - block_dim_constant(), - N_READS>; + kernel = cu:: + arg_reduce_general, block_dim(), N_READS>; } - kernel<<>>( + kernel<<>>( in.data(), out.data(), out.size(),