From 85a8824a8cd13d432a76551c176e43d5f10cbe07 Mon Sep 17 00:00:00 2001 From: AN Long Date: Thu, 9 Oct 2025 07:25:38 +0900 Subject: [PATCH] Fix cumulative operations when axis=None (#2653) --- mlx/ops.cpp | 41 +++++++++++++++++++++++++++++++++++++++++ mlx/ops.h | 35 +++++++++++++++++++++++++++++++++++ python/src/array.cpp | 24 +++++------------------- 3 files changed, 81 insertions(+), 19 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ffc7e4bb4..a65709752 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3468,6 +3468,14 @@ array cumsum( {a}); } +array cumsum( + const array& a, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + return cumsum(flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s)); +} + array cumprod( const array& a, int axis, @@ -3490,6 +3498,14 @@ array cumprod( {a}); } +array cumprod( + const array& a, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + return cumprod(flatten(a, s), 0, reverse, inclusive, s); +} + array cummax( const array& a, int axis, @@ -3512,6 +3528,14 @@ array cummax( {a}); } +array cummax( + const array& a, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + return cummax(flatten(a, s), 0, reverse, inclusive, s); +} + array cummin( const array& a, int axis, @@ -3534,6 +3558,14 @@ array cummin( {a}); } +array cummin( + const array& a, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + return cummin(flatten(a, s), 0, reverse, inclusive, s); +} + array logcumsumexp( const array& a, int axis, @@ -3556,6 +3588,15 @@ array logcumsumexp( {a}); } +array logcumsumexp( + const array& a, + bool reverse /* = false*/, + bool inclusive /* = true*/, + StreamOrDevice s /* = {}*/) { + return logcumsumexp( + flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s)); +} + /** Convolution operations */ namespace { diff --git a/mlx/ops.h b/mlx/ops.h index 826f6d47b..cb3505df4 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -716,6 +716,13 @@ array topk(const array& a, int k, StreamOrDevice s = {}); array topk(const array& a, int k, int axis, StreamOrDevice s = {}); /** Cumulative logsumexp of an array. */ +array logcumsumexp( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative logsumexp of an array along the given axis. */ array logcumsumexp( const array& a, int axis, @@ -1186,6 +1193,13 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) { array power(const array& a, const array& b, StreamOrDevice s = {}); /** Cumulative sum of an array. */ +array cumsum( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative sum of an array along the given axis. */ array cumsum( const array& a, int axis, @@ -1194,6 +1208,13 @@ array cumsum( StreamOrDevice s = {}); /** Cumulative product of an array. */ +array cumprod( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative product of an array along the given axis. */ array cumprod( const array& a, int axis, @@ -1202,6 +1223,13 @@ array cumprod( StreamOrDevice s = {}); /** Cumulative max of an array. */ +array cummax( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative max of an array along the given axis. */ array cummax( const array& a, int axis, @@ -1210,6 +1238,13 @@ array cummax( StreamOrDevice s = {}); /** Cumulative min of an array. */ +array cummin( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative min of an array along the given axis. */ array cummin( const array& a, int axis, diff --git a/python/src/array.cpp b/python/src/array.cpp index 9367d4e09..231474c2d 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1275,10 +1275,7 @@ void init_array(nb::module_& m) { if (axis) { return mx::logcumsumexp(a, *axis, reverse, inclusive, s); } else { - // TODO: Implement that in the C++ API as well. See concatenate - // above. - return mx::logcumsumexp( - mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::logcumsumexp(a, reverse, inclusive, s); } }, "axis"_a = nb::none(), @@ -1408,9 +1405,7 @@ void init_array(nb::module_& m) { if (axis) { return mx::cumsum(a, *axis, reverse, inclusive, s); } else { - // TODO: Implement that in the C++ API as well. See concatenate - // above. - return mx::cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cumsum(a, reverse, inclusive, s); } }, "axis"_a = nb::none(), @@ -1429,10 +1424,7 @@ void init_array(nb::module_& m) { if (axis) { return mx::cumprod(a, *axis, reverse, inclusive, s); } else { - // TODO: Implement that in the C++ API as well. See concatenate - // above. - return mx::cumprod( - mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cumprod(a, reverse, inclusive, s); } }, "axis"_a = nb::none(), @@ -1451,10 +1443,7 @@ void init_array(nb::module_& m) { if (axis) { return mx::cummax(a, *axis, reverse, inclusive, s); } else { - // TODO: Implement that in the C++ API as well. See concatenate - // above. - return mx::cummax( - mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cummax(a, reverse, inclusive, s); } }, "axis"_a = nb::none(), @@ -1473,10 +1462,7 @@ void init_array(nb::module_& m) { if (axis) { return mx::cummin(a, *axis, reverse, inclusive, s); } else { - // TODO: Implement that in the C++ API as well. See concatenate - // above. - return mx::cummin( - mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cummin(a, reverse, inclusive, s); } }, "axis"_a = nb::none(),