mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Make args references but ensure copy to kernel
This commit is contained in:
@@ -247,7 +247,7 @@ void col_reduce_looped(
|
|||||||
Reduce::ReduceType reduce_type,
|
Reduce::ReduceType reduce_type,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan,
|
const ReductionPlan& plan,
|
||||||
cu::ColReduceArgs args) {
|
const cu::ColReduceArgs& args) {
|
||||||
// Allocate data for the output using in's layout to access them as
|
// Allocate data for the output using in's layout to access them as
|
||||||
// contiguously as possible.
|
// contiguously as possible.
|
||||||
allocate_same_layout(out, in, axes);
|
allocate_same_layout(out, in, axes);
|
||||||
@@ -271,7 +271,13 @@ void col_reduce_looped(
|
|||||||
auto kernel =
|
auto kernel =
|
||||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, blocks, 0, indata, out.data<U>(), args);
|
kernel,
|
||||||
|
grid,
|
||||||
|
blocks,
|
||||||
|
0,
|
||||||
|
indata,
|
||||||
|
out.data<U>(),
|
||||||
|
static_cast<cu::ColReduceArgs>(args));
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -284,7 +290,7 @@ void col_reduce_small(
|
|||||||
Reduce::ReduceType reduce_type,
|
Reduce::ReduceType reduce_type,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan,
|
const ReductionPlan& plan,
|
||||||
cu::ColReduceArgs args) {
|
const cu::ColReduceArgs& args) {
|
||||||
// Allocate data for the output using in's layout to access them as
|
// Allocate data for the output using in's layout to access them as
|
||||||
// contiguously as possible.
|
// contiguously as possible.
|
||||||
allocate_same_layout(out, in, axes);
|
allocate_same_layout(out, in, axes);
|
||||||
@@ -308,7 +314,7 @@ void col_reduce_small(
|
|||||||
0,
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
args,
|
static_cast<cu::ColReduceArgs>(args),
|
||||||
out.size());
|
out.size());
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user