mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Make args references but ensure copy to kernel
This commit is contained in:
@@ -247,7 +247,7 @@ void col_reduce_looped(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::ColReduceArgs args) {
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -271,7 +271,13 @@ void col_reduce_looped(
|
||||
auto kernel =
|
||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, blocks, 0, indata, out.data<U>(), args);
|
||||
kernel,
|
||||
grid,
|
||||
blocks,
|
||||
0,
|
||||
indata,
|
||||
out.data<U>(),
|
||||
static_cast<cu::ColReduceArgs>(args));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -284,7 +290,7 @@ void col_reduce_small(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::ColReduceArgs args) {
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -308,7 +314,7 @@ void col_reduce_small(
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
args,
|
||||
static_cast<cu::ColReduceArgs>(args),
|
||||
out.size());
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user