mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Fix cumulative operations when axis=None (#2653)
This commit is contained in:
41
mlx/ops.cpp
41
mlx/ops.cpp
@@ -3468,6 +3468,14 @@ array cumsum(
|
|||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array cumsum(
|
||||||
|
const array& a,
|
||||||
|
bool reverse /* = false*/,
|
||||||
|
bool inclusive /* = true*/,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
|
return cumsum(flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
array cumprod(
|
array cumprod(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
@@ -3490,6 +3498,14 @@ array cumprod(
|
|||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array cumprod(
|
||||||
|
const array& a,
|
||||||
|
bool reverse /* = false*/,
|
||||||
|
bool inclusive /* = true*/,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
|
return cumprod(flatten(a, s), 0, reverse, inclusive, s);
|
||||||
|
}
|
||||||
|
|
||||||
array cummax(
|
array cummax(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
@@ -3512,6 +3528,14 @@ array cummax(
|
|||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array cummax(
|
||||||
|
const array& a,
|
||||||
|
bool reverse /* = false*/,
|
||||||
|
bool inclusive /* = true*/,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
|
return cummax(flatten(a, s), 0, reverse, inclusive, s);
|
||||||
|
}
|
||||||
|
|
||||||
array cummin(
|
array cummin(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
@@ -3534,6 +3558,14 @@ array cummin(
|
|||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array cummin(
|
||||||
|
const array& a,
|
||||||
|
bool reverse /* = false*/,
|
||||||
|
bool inclusive /* = true*/,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
|
return cummin(flatten(a, s), 0, reverse, inclusive, s);
|
||||||
|
}
|
||||||
|
|
||||||
array logcumsumexp(
|
array logcumsumexp(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
@@ -3556,6 +3588,15 @@ array logcumsumexp(
|
|||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array logcumsumexp(
|
||||||
|
const array& a,
|
||||||
|
bool reverse /* = false*/,
|
||||||
|
bool inclusive /* = true*/,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
|
return logcumsumexp(
|
||||||
|
flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
/** Convolution operations */
|
/** Convolution operations */
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
35
mlx/ops.h
35
mlx/ops.h
@@ -716,6 +716,13 @@ array topk(const array& a, int k, StreamOrDevice s = {});
|
|||||||
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
|
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Cumulative logsumexp of an array. */
|
/** Cumulative logsumexp of an array. */
|
||||||
|
array logcumsumexp(
|
||||||
|
const array& a,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cumulative logsumexp of an array along the given axis. */
|
||||||
array logcumsumexp(
|
array logcumsumexp(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
@@ -1186,6 +1193,13 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
|
|||||||
array power(const array& a, const array& b, StreamOrDevice s = {});
|
array power(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Cumulative sum of an array. */
|
/** Cumulative sum of an array. */
|
||||||
|
array cumsum(
|
||||||
|
const array& a,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cumulative sum of an array along the given axis. */
|
||||||
array cumsum(
|
array cumsum(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
@@ -1194,6 +1208,13 @@ array cumsum(
|
|||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Cumulative product of an array. */
|
/** Cumulative product of an array. */
|
||||||
|
array cumprod(
|
||||||
|
const array& a,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cumulative product of an array along the given axis. */
|
||||||
array cumprod(
|
array cumprod(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
@@ -1202,6 +1223,13 @@ array cumprod(
|
|||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Cumulative max of an array. */
|
/** Cumulative max of an array. */
|
||||||
|
array cummax(
|
||||||
|
const array& a,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cumulative max of an array along the given axis. */
|
||||||
array cummax(
|
array cummax(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
@@ -1210,6 +1238,13 @@ array cummax(
|
|||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Cumulative min of an array. */
|
/** Cumulative min of an array. */
|
||||||
|
array cummin(
|
||||||
|
const array& a,
|
||||||
|
bool reverse = false,
|
||||||
|
bool inclusive = true,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Cumulative min of an array along the given axis. */
|
||||||
array cummin(
|
array cummin(
|
||||||
const array& a,
|
const array& a,
|
||||||
int axis,
|
int axis,
|
||||||
|
@@ -1275,10 +1275,7 @@ void init_array(nb::module_& m) {
|
|||||||
if (axis) {
|
if (axis) {
|
||||||
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
|
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
|
||||||
} else {
|
} else {
|
||||||
// TODO: Implement that in the C++ API as well. See concatenate
|
return mx::logcumsumexp(a, reverse, inclusive, s);
|
||||||
// above.
|
|
||||||
return mx::logcumsumexp(
|
|
||||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"axis"_a = nb::none(),
|
"axis"_a = nb::none(),
|
||||||
@@ -1408,9 +1405,7 @@ void init_array(nb::module_& m) {
|
|||||||
if (axis) {
|
if (axis) {
|
||||||
return mx::cumsum(a, *axis, reverse, inclusive, s);
|
return mx::cumsum(a, *axis, reverse, inclusive, s);
|
||||||
} else {
|
} else {
|
||||||
// TODO: Implement that in the C++ API as well. See concatenate
|
return mx::cumsum(a, reverse, inclusive, s);
|
||||||
// above.
|
|
||||||
return mx::cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"axis"_a = nb::none(),
|
"axis"_a = nb::none(),
|
||||||
@@ -1429,10 +1424,7 @@ void init_array(nb::module_& m) {
|
|||||||
if (axis) {
|
if (axis) {
|
||||||
return mx::cumprod(a, *axis, reverse, inclusive, s);
|
return mx::cumprod(a, *axis, reverse, inclusive, s);
|
||||||
} else {
|
} else {
|
||||||
// TODO: Implement that in the C++ API as well. See concatenate
|
return mx::cumprod(a, reverse, inclusive, s);
|
||||||
// above.
|
|
||||||
return mx::cumprod(
|
|
||||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"axis"_a = nb::none(),
|
"axis"_a = nb::none(),
|
||||||
@@ -1451,10 +1443,7 @@ void init_array(nb::module_& m) {
|
|||||||
if (axis) {
|
if (axis) {
|
||||||
return mx::cummax(a, *axis, reverse, inclusive, s);
|
return mx::cummax(a, *axis, reverse, inclusive, s);
|
||||||
} else {
|
} else {
|
||||||
// TODO: Implement that in the C++ API as well. See concatenate
|
return mx::cummax(a, reverse, inclusive, s);
|
||||||
// above.
|
|
||||||
return mx::cummax(
|
|
||||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"axis"_a = nb::none(),
|
"axis"_a = nb::none(),
|
||||||
@@ -1473,10 +1462,7 @@ void init_array(nb::module_& m) {
|
|||||||
if (axis) {
|
if (axis) {
|
||||||
return mx::cummin(a, *axis, reverse, inclusive, s);
|
return mx::cummin(a, *axis, reverse, inclusive, s);
|
||||||
} else {
|
} else {
|
||||||
// TODO: Implement that in the C++ API as well. See concatenate
|
return mx::cummin(a, reverse, inclusive, s);
|
||||||
// above.
|
|
||||||
return mx::cummin(
|
|
||||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"axis"_a = nb::none(),
|
"axis"_a = nb::none(),
|
||||||
|
Reference in New Issue
Block a user