From ca7970a4f115ae3bdcdab2ddb01c6e5b316a070d Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 2 Oct 2025 11:39:21 -0700
Subject: [PATCH] Make args references but ensure copy to kernel

---
 mlx/backend/cuda/reduce/col_reduce.cu | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu
index 6b18d1d0a..feafc2fb0 100644
--- a/mlx/backend/cuda/reduce/col_reduce.cu
+++ b/mlx/backend/cuda/reduce/col_reduce.cu
@@ -247,7 +247,7 @@ void col_reduce_looped(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan,
-    cu::ColReduceArgs args) {
+    const cu::ColReduceArgs& args) {
   // Allocate data for the output using in's layout to access them as
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
@@ -271,7 +271,13 @@ void col_reduce_looped(
        auto kernel = cu::col_reduce_looped;
        encoder.add_kernel_node(
-            kernel, grid, blocks, 0, indata, out.data<U>(), args);
+            kernel,
+            grid,
+            blocks,
+            0,
+            indata,
+            out.data<U>(),
+            static_cast<cu::ColReduceArgs>(args));
      });
    });
  });
@@ -284,7 +290,7 @@ void col_reduce_small(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan,
-    cu::ColReduceArgs args) {
+    const cu::ColReduceArgs& args) {
   // Allocate data for the output using in's layout to access them as
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
@@ -308,7 +314,7 @@ void col_reduce_small(
            0,
            in.data<T>(),
            out.data<U>(),
-            args,
+            static_cast<cu::ColReduceArgs>(args),
            out.size());
      });
    });