Fix reduce edge case (#1389)

This commit is contained in:
Angelos Katharopoulos
2024-09-01 21:37:51 -07:00
committed by GitHub
parent 9592766939
commit 969337345f
2 changed files with 16 additions and 16 deletions

View File

@@ -32,7 +32,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1]) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {