Compare commits

...

2 Commits

Author SHA1 Message Date
Angelos Katharopoulos
ca7970a4f1 Make args references but ensure copy to kernel 2025-10-02 11:39:21 -07:00
Angelos Katharopoulos
214b1c1a06 Remove moves 2025-10-02 11:16:17 -07:00

View File

@@ -247,7 +247,7 @@ void col_reduce_looped(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -271,7 +271,13 @@ void col_reduce_looped(
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel, grid, blocks, 0, indata, out.data<U>(), args);
kernel,
grid,
blocks,
0,
indata,
out.data<U>(),
static_cast<cu::ColReduceArgs>(args));
});
});
});
@@ -284,7 +290,7 @@ void col_reduce_small(
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::ColReduceArgs args) {
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
@@ -308,7 +314,7 @@ void col_reduce_small(
0,
in.data<T>(),
out.data<U>(),
args,
static_cast<cu::ColReduceArgs>(args),
out.size());
});
});
@@ -339,13 +345,12 @@ void col_reduce(
// Small col reduce with a single or contiguous reduction axis
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
args.reduction_stride % (16 / in.itemsize()) == 0) {
col_reduce_small(
encoder, in, out, reduce_type, axes, plan, std::move(args));
col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}
} // namespace mlx::core