mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
LogCumSumExp (#2069)
This commit is contained in:
22
mlx/ops.cpp
22
mlx/ops.cpp
@@ -3504,6 +3504,28 @@ array cummin(
|
||||
{a});
|
||||
}
|
||||
|
||||
array logcumsumexp(
|
||||
const array& a,
|
||||
int axis,
|
||||
bool reverse /* = false*/,
|
||||
bool inclusive /* = true*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
int ndim = a.ndim();
|
||||
if (axis >= ndim || axis < -ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
axis = (axis + a.ndim()) % a.ndim();
|
||||
return array(
|
||||
a.shape(),
|
||||
a.dtype(),
|
||||
std::make_shared<Scan>(
|
||||
to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive),
|
||||
{a});
|
||||
}
|
||||
|
||||
/** Convolution operations */
|
||||
|
||||
namespace {
|
||||
|
||||
Reference in New Issue
Block a user