mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Fix reduce edge case (#1389)
This commit is contained in:

committed by
GitHub

parent
9592766939
commit
969337345f
@@ -32,7 +32,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1]) {
|
||||
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
|
Reference in New Issue
Block a user