Take args by const reference, but ensure they are copied by value when passed to the kernel

This commit is contained in:
Angelos Katharopoulos
2025-10-02 11:39:21 -07:00
parent 214b1c1a06
commit ca7970a4f1

View File

@@ -247,7 +247,7 @@ void col_reduce_looped(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan, const ReductionPlan& plan,
cu::ColReduceArgs args) { const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as // Allocate data for the output using in's layout to access them as
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -271,7 +271,13 @@ void col_reduce_looped(
auto kernel = auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>; cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, grid, blocks, 0, indata, out.data<U>(), args); kernel,
grid,
blocks,
0,
indata,
out.data<U>(),
static_cast<cu::ColReduceArgs>(args));
}); });
}); });
}); });
@@ -284,7 +290,7 @@ void col_reduce_small(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan, const ReductionPlan& plan,
cu::ColReduceArgs args) { const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as // Allocate data for the output using in's layout to access them as
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -308,7 +314,7 @@ void col_reduce_small(
0, 0,
in.data<T>(), in.data<T>(),
out.data<U>(), out.data<U>(),
args, static_cast<cu::ColReduceArgs>(args),
out.size()); out.size());
}); });
}); });