Refactor reductions and fix scatter atomics for large sizes (#1300)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-08-22 16:03:31 -07:00
committed by GitHub
parent f9e00efe31
commit 98b6ce3460
18 changed files with 1584 additions and 1235 deletions

View File

@@ -49,7 +49,7 @@ struct ReductionPlan {
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Helper for the ndimensional strided loop
// Should this be in utils?

View File

@@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
@@ -41,6 +41,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
}
}
// Remove singleton axes from the plan
for (int i = shape.size() - 1; i >= 0; i--) {
if (shape[i] == 1) {
shape.erase(shape.begin() + i);
strides.erase(strides.begin() + i);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
@@ -63,10 +71,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
return a.second > b.second;
bool a_is_zero = a.second == 0;
bool b_is_zero = b.second == 0;
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
@@ -98,16 +110,33 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
if (x.strides()[i] != size) {
size_t stride_i = x.strides()[i];
int shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;
}
have_expand = true;
break;
}
size *= x.shape(i);
if (stride_i != size && shape_i != 1) {
break;
}
size *= shape_i;
}
if (size >= strides.back()) {
// In the case of an expanded dimension we are being conservative and
// require the smallest reduction stride to be smaller than the maximum row
// contiguous size. The reason is that we can't easily know if the reduced
// axis is before or after an expanded dimension.
if (size > strides.back() || (size == strides.back() && !have_expand)) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}

View File

@@ -104,6 +104,33 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
std::vector<array>{std::forward<Arrays>(xs)...});
}
// The single array version of the above.
inline std::tuple<std::vector<int>, std::vector<size_t>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::vector<int> collapsed_shape;
std::vector<size_t> collapsed_strides;
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
collapsed_strides.push_back(strides[0]);
for (int i = 1; i < shape.size(); i++) {
if (strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<size_t>(shape[i]) >
std::numeric_limits<int>::max()) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
collapsed_shape.back() *= shape[i];
collapsed_strides.back() = strides[i];
}
}
}
return std::make_tuple(collapsed_shape, collapsed_strides);
}
template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,