Make args references but ensure copy to kernel

This commit is contained in:
Angelos Katharopoulos
2025-10-02 11:39:21 -07:00
parent 214b1c1a06
commit ca7970a4f1

View File

@@ -247,7 +247,7 @@ void col_reduce_looped(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -271,7 +271,13 @@ void col_reduce_looped(
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel, grid, blocks, 0, indata, out.data<U>(), args);
kernel,
grid,
blocks,
0,
indata,
out.data<U>(),
static_cast<cu::ColReduceArgs>(args));
});
});
});
@@ -284,7 +290,7 @@ void col_reduce_small(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -308,7 +314,7 @@ void col_reduce_small(
0,
in.data<T>(),
out.data<U>(),
args,
static_cast<cu::ColReduceArgs>(args),
out.size());
});
});