mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
LogCumSumExp (#2069)
This commit is contained in:
@@ -1202,6 +1202,28 @@ void init_array(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
"See :func:`max`.")
|
||||
.def(
|
||||
"logcumsumexp",
|
||||
[](const mx::array& a,
|
||||
std::optional<int> axis,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axis) {
|
||||
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
|
||||
} else {
|
||||
// TODO: Implement that in the C++ API as well. See concatenate
|
||||
// above.
|
||||
return mx::logcumsumexp(
|
||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
||||
}
|
||||
},
|
||||
"axis"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"reverse"_a = false,
|
||||
"inclusive"_a = true,
|
||||
"stream"_a = nb::none(),
|
||||
"See :func:`logcumsumexp`.")
|
||||
.def(
|
||||
"logsumexp",
|
||||
[](const mx::array& a,
|
||||
|
@@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The output array with the corresponding axes reduced.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"logcumsumexp",
|
||||
[](const mx::array& a,
|
||||
std::optional<int> axis,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axis) {
|
||||
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
|
||||
} else {
|
||||
return mx::logcumsumexp(
|
||||
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
"axis"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"reverse"_a = false,
|
||||
"inclusive"_a = true,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Return the cumulative logsumexp of the elements along the given axis.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
axis (int, optional): Optional axis to compute the cumulative logsumexp
|
||||
over. If unspecified the cumulative logsumexp of the flattened array is
|
||||
returned.
|
||||
reverse (bool): Perform the cumulative logsumexp in reverse.
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"logsumexp",
|
||||
[](const mx::array& a,
|
||||
|
Reference in New Issue
Block a user