LogCumSumExp (#2069)

This commit is contained in:
Yury Popov
2025-04-13 11:27:29 +03:00
committed by GitHub
parent 7275ac7523
commit e9e268336b
15 changed files with 209 additions and 3 deletions

View File

@@ -3504,6 +3504,28 @@ array cummin(
{a});
}
array logcumsumexp(
const array& a,
int axis,
bool reverse /* = false*/,
bool inclusive /* = true*/,
StreamOrDevice s /* = {}*/) {
int ndim = a.ndim();
if (axis >= ndim || axis < -ndim) {
std::ostringstream msg;
msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
axis = (axis + a.ndim()) % a.ndim();
return array(
a.shape(),
a.dtype(),
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive),
{a});
}
/** Convolution operations */
namespace {