mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
e42e06046e
...
ca7970a4f1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca7970a4f1 | ||
|
|
214b1c1a06 |
@@ -247,7 +247,7 @@ void col_reduce_looped(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::ColReduceArgs args) {
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -271,7 +271,13 @@ void col_reduce_looped(
|
||||
auto kernel =
|
||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, blocks, 0, indata, out.data<U>(), args);
|
||||
kernel,
|
||||
grid,
|
||||
blocks,
|
||||
0,
|
||||
indata,
|
||||
out.data<U>(),
|
||||
static_cast<cu::ColReduceArgs>(args));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -284,7 +290,7 @@ void col_reduce_small(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::ColReduceArgs args) {
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -308,7 +314,7 @@ void col_reduce_small(
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
args,
|
||||
static_cast<cu::ColReduceArgs>(args),
|
||||
out.size());
|
||||
});
|
||||
});
|
||||
@@ -339,13 +345,12 @@ void col_reduce(
|
||||
// Small col reduce with a single or contiguous reduction axis
|
||||
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
|
||||
args.reduction_stride % (16 / in.itemsize()) == 0) {
|
||||
col_reduce_small(
|
||||
encoder, in, out, reduce_type, axes, plan, std::move(args));
|
||||
col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback col reduce
|
||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
|
||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user