mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
Fix cumulative operations when axis=None (#2653)
This commit is contained in:
@@ -1275,10 +1275,7 @@ void init_array(nb::module_& m) {
|
||||
if (axis) {
|
||||
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
|
||||
} else {
|
||||
// TODO: Implement that in the C++ API as well. See concatenate
|
||||
// above.
|
||||
return mx::logcumsumexp(
|
||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
||||
return mx::logcumsumexp(a, reverse, inclusive, s);
|
||||
}
|
||||
},
|
||||
"axis"_a = nb::none(),
|
||||
@@ -1408,9 +1405,7 @@ void init_array(nb::module_& m) {
|
||||
if (axis) {
|
||||
return mx::cumsum(a, *axis, reverse, inclusive, s);
|
||||
} else {
|
||||
// TODO: Implement that in the C++ API as well. See concatenate
|
||||
// above.
|
||||
return mx::cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
||||
return mx::cumsum(a, reverse, inclusive, s);
|
||||
}
|
||||
},
|
||||
"axis"_a = nb::none(),
|
||||
@@ -1429,10 +1424,7 @@ void init_array(nb::module_& m) {
|
||||
if (axis) {
|
||||
return mx::cumprod(a, *axis, reverse, inclusive, s);
|
||||
} else {
|
||||
// TODO: Implement that in the C++ API as well. See concatenate
|
||||
// above.
|
||||
return mx::cumprod(
|
||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
||||
return mx::cumprod(a, reverse, inclusive, s);
|
||||
}
|
||||
},
|
||||
"axis"_a = nb::none(),
|
||||
@@ -1451,10 +1443,7 @@ void init_array(nb::module_& m) {
|
||||
if (axis) {
|
||||
return mx::cummax(a, *axis, reverse, inclusive, s);
|
||||
} else {
|
||||
// TODO: Implement that in the C++ API as well. See concatenate
|
||||
// above.
|
||||
return mx::cummax(
|
||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
||||
return mx::cummax(a, reverse, inclusive, s);
|
||||
}
|
||||
},
|
||||
"axis"_a = nb::none(),
|
||||
@@ -1473,10 +1462,7 @@ void init_array(nb::module_& m) {
|
||||
if (axis) {
|
||||
return mx::cummin(a, *axis, reverse, inclusive, s);
|
||||
} else {
|
||||
// TODO: Implement that in the C++ API as well. See concatenate
|
||||
// above.
|
||||
return mx::cummin(
|
||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
||||
return mx::cummin(a, reverse, inclusive, s);
|
||||
}
|
||||
},
|
||||
"axis"_a = nb::none(),
|
||||
|
Reference in New Issue
Block a user