Refactor the reduction kernels (#277)

This commit is contained in:
Angelos Katharopoulos
2023-12-24 14:47:57 -08:00
committed by GitHub
parent 22fee5a383
commit 9e6b8c9f48
4 changed files with 179 additions and 369 deletions

View File

@@ -126,7 +126,7 @@ struct ReductionPlan {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
(x.flags().row_contiguous || x.flags().col_contiguous)) {
x.flags().contiguous) {
return ContiguousAllReduce;
}