LogCumSumExp (#2069)

This commit is contained in:
Yury Popov
2025-04-13 11:27:29 +03:00
committed by GitHub
parent 7275ac7523
commit e9e268336b
15 changed files with 209 additions and 3 deletions

View File

@@ -1202,6 +1202,28 @@ void init_array(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
"See :func:`max`.")
.def(
"logcumsumexp",
[](const mx::array& a,
std::optional<int> axis,
bool reverse,
bool inclusive,
mx::StreamOrDevice s) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"stream"_a = nb::none(),
"See :func:`logcumsumexp`.")
.def(
"logsumexp",
[](const mx::array& a,

View File

@@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) {
Returns:
array: The output array with the corresponding axes reduced.
)pbdoc");
m.def(
"logcumsumexp",
[](const mx::array& a,
std::optional<int> axis,
bool reverse,
bool inclusive,
mx::StreamOrDevice s) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
}
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"stream"_a = nb::none(),
nb::sig(
"def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the cumulative logsumexp of the elements along the given axis.
Args:
a (array): Input array
axis (int, optional): Optional axis to compute the cumulative logsumexp
over. If unspecified the cumulative logsumexp of the flattened array is
returned.
reverse (bool): Perform the cumulative logsumexp in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"logsumexp",
[](const mx::array& a,