Compare commits

...

2 Commits

Author SHA1 Message Date
Angelos Katharopoulos
ca7970a4f1 Make args references but ensure copy to kernel 2025-10-02 11:39:21 -07:00
Angelos Katharopoulos
214b1c1a06 Remove moves 2025-10-02 11:16:17 -07:00

View File

@@ -247,7 +247,7 @@ void col_reduce_looped(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan, const ReductionPlan& plan,
cu::ColReduceArgs args) { const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as // Allocate data for the output using in's layout to access them as
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -271,7 +271,13 @@ void col_reduce_looped(
auto kernel = auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>; cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, grid, blocks, 0, indata, out.data<U>(), args); kernel,
grid,
blocks,
0,
indata,
out.data<U>(),
static_cast<cu::ColReduceArgs>(args));
}); });
}); });
}); });
@@ -284,7 +290,7 @@ void col_reduce_small(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan, const ReductionPlan& plan,
cu::ColReduceArgs args) { const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as // Allocate data for the output using in's layout to access them as
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -308,7 +314,7 @@ void col_reduce_small(
0, 0,
in.data<T>(), in.data<T>(),
out.data<U>(), out.data<U>(),
args, static_cast<cu::ColReduceArgs>(args),
out.size()); out.size());
}); });
}); });
@@ -339,13 +345,12 @@ void col_reduce(
// Small col reduce with a single or contiguous reduction axis // Small col reduce with a single or contiguous reduction axis
if (args.non_col_reductions == 1 && args.reduction_size <= 32 && if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
args.reduction_stride % (16 / in.itemsize()) == 0) { args.reduction_stride % (16 / in.itemsize()) == 0) {
col_reduce_small( col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
encoder, in, out, reduce_type, axes, plan, std::move(args));
return; return;
} }
// Fallback col reduce // Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args)); col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
} }
} // namespace mlx::core } // namespace mlx::core