mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
e42e06046e
...
ca7970a4f1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca7970a4f1 | ||
|
|
214b1c1a06 |
@@ -247,7 +247,7 @@ void col_reduce_looped(
|
|||||||
Reduce::ReduceType reduce_type,
|
Reduce::ReduceType reduce_type,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan,
|
const ReductionPlan& plan,
|
||||||
cu::ColReduceArgs args) {
|
const cu::ColReduceArgs& args) {
|
||||||
// Allocate data for the output using in's layout to access them as
|
// Allocate data for the output using in's layout to access them as
|
||||||
// contiguously as possible.
|
// contiguously as possible.
|
||||||
allocate_same_layout(out, in, axes);
|
allocate_same_layout(out, in, axes);
|
||||||
@@ -271,7 +271,13 @@ void col_reduce_looped(
|
|||||||
auto kernel =
|
auto kernel =
|
||||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, blocks, 0, indata, out.data<U>(), args);
|
kernel,
|
||||||
|
grid,
|
||||||
|
blocks,
|
||||||
|
0,
|
||||||
|
indata,
|
||||||
|
out.data<U>(),
|
||||||
|
static_cast<cu::ColReduceArgs>(args));
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -284,7 +290,7 @@ void col_reduce_small(
|
|||||||
Reduce::ReduceType reduce_type,
|
Reduce::ReduceType reduce_type,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan,
|
const ReductionPlan& plan,
|
||||||
cu::ColReduceArgs args) {
|
const cu::ColReduceArgs& args) {
|
||||||
// Allocate data for the output using in's layout to access them as
|
// Allocate data for the output using in's layout to access them as
|
||||||
// contiguously as possible.
|
// contiguously as possible.
|
||||||
allocate_same_layout(out, in, axes);
|
allocate_same_layout(out, in, axes);
|
||||||
@@ -308,7 +314,7 @@ void col_reduce_small(
|
|||||||
0,
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
args,
|
static_cast<cu::ColReduceArgs>(args),
|
||||||
out.size());
|
out.size());
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -339,13 +345,12 @@ void col_reduce(
|
|||||||
// Small col reduce with a single or contiguous reduction axis
|
// Small col reduce with a single or contiguous reduction axis
|
||||||
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
|
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
|
||||||
args.reduction_stride % (16 / in.itemsize()) == 0) {
|
args.reduction_stride % (16 / in.itemsize()) == 0) {
|
||||||
col_reduce_small(
|
col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
|
||||||
encoder, in, out, reduce_type, axes, plan, std::move(args));
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback col reduce
|
// Fallback col reduce
|
||||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
|
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
Reference in New Issue
Block a user