mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Refactor reductions and fix scatter atomics for large sizes (#1300)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -49,7 +49,7 @@ struct ReductionPlan {
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
||||
|
||||
// Helper for the n-dimensional strided loop
|
||||
// Should this be in utils?
|
||||
|
||||
@@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
@@ -41,6 +41,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
}
|
||||
}
|
||||
|
||||
// Remove singleton axes from the plan
|
||||
for (int i = shape.size() - 1; i >= 0; i--) {
|
||||
if (shape[i] == 1) {
|
||||
shape.erase(shape.begin() + i);
|
||||
strides.erase(strides.begin() + i);
|
||||
}
|
||||
}
|
||||
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
@@ -63,10 +71,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, size_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
if (x.shape(a) > 1) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
}
|
||||
}
|
||||
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||
return a.second > b.second;
|
||||
bool a_is_zero = a.second == 0;
|
||||
bool b_is_zero = b.second == 0;
|
||||
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
|
||||
});
|
||||
// Extract the two smallest and try to merge them in case the contiguous
|
||||
// reduction can be bigger than just the last axis.
|
||||
@@ -98,16 +110,33 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int size = 1;
|
||||
bool have_expand = false;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
if (x.strides()[i] != size) {
|
||||
|
||||
size_t stride_i = x.strides()[i];
|
||||
int shape_i = x.shape(i);
|
||||
if (stride_i == 0) {
|
||||
if (shape_i == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
have_expand = true;
|
||||
break;
|
||||
}
|
||||
size *= x.shape(i);
|
||||
|
||||
if (stride_i != size && shape_i != 1) {
|
||||
break;
|
||||
}
|
||||
size *= shape_i;
|
||||
}
|
||||
if (size >= strides.back()) {
|
||||
// In the case of an expanded dimension we are being conservative and
|
||||
// require the smallest reduction stride to be smaller than the maximum row
|
||||
// contiguous size. The reason is that we can't easily know if the reduced
|
||||
// axis is before or after an expanded dimension.
|
||||
if (size > strides.back() || (size == strides.back() && !have_expand)) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,33 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
|
||||
std::vector<array>{std::forward<Arrays>(xs)...});
|
||||
}
|
||||
|
||||
// Single-array variant of collapse_contiguous_dims: merges adjacent
// dimensions of one shape/strides pair whenever they can be addressed as a
// single dimension.
//
// Dimension i is folded into the most recently emitted dimension when
//   strides[i] * shape[i] == stride of the most recently emitted dimension
// and the merged extent still fits in an int. Otherwise dimension i starts a
// new output dimension. A merged dimension carries the stride of the last
// (innermost) dimension folded into it.
//
// Parameters:
//   shape   - extent of each dimension.
//   strides - stride of each dimension; must hold at least shape.size()
//             entries (this is not checked).
//
// Returns a tuple (collapsed_shape, collapsed_strides); both vectors are
// empty when shape is empty.
inline std::tuple<std::vector<int>, std::vector<size_t>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  // Build the result directly inside the returned tuple so the two vectors
  // are never copied on the way out.
  std::tuple<std::vector<int>, std::vector<size_t>> collapsed;
  auto& collapsed_shape = std::get<0>(collapsed);
  auto& collapsed_strides = std::get<1>(collapsed);

  if (shape.empty()) {
    return collapsed;
  }

  // Largest extent a merged dimension may have (shapes are stored as int).
  const size_t max_extent =
      static_cast<size_t>(std::numeric_limits<int>::max());

  collapsed_shape.push_back(shape[0]);
  collapsed_strides.push_back(strides[0]);

  for (size_t i = 1; i < shape.size(); ++i) {
    // The current dimension exactly tiles one step of the previous one.
    const bool contiguous_with_prev =
        strides[i] * shape[i] == collapsed_strides.back();
    // The product is evaluated in size_t so it cannot overflow int before
    // the comparison against the cap.
    const bool fits_in_int =
        collapsed_shape.back() * static_cast<size_t>(shape[i]) <= max_extent;

    if (contiguous_with_prev && fits_in_int) {
      collapsed_shape.back() *= shape[i];
      collapsed_strides.back() = strides[i];
    } else {
      collapsed_shape.push_back(shape[i]);
      collapsed_strides.push_back(strides[i]);
    }
  }

  return collapsed;
}
|
||||
|
||||
template <typename stride_t>
|
||||
inline auto check_contiguity(
|
||||
const std::vector<int>& shape,
|
||||
|
||||
Reference in New Issue
Block a user