LogCumSumExp (#2069)

Yury Popov, 2025-04-13 11:27:29 +03:00, committed by GitHub
parent 7275ac7523
commit e9e268336b
15 changed files with 209 additions and 3 deletions
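
As a quick orientation, this is roughly how the new op is used from Python once the bindings below are in place (a minimal sketch with illustrative values; only logcumsumexp itself comes from this commit):

import mlx.core as mx

x = mx.array([0.0, 1.0, 2.0])
y = mx.logcumsumexp(x, axis=0)                    # stable log(cumsum(exp(x))) along axis 0
y_rev = mx.logcumsumexp(x, axis=0, reverse=True)  # scan from the last element backwards
y_flat = mx.logcumsumexp(x)                       # no axis: operates on the flattened array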

View File

@@ -38,6 +38,7 @@ Array
array.log10
array.log1p
array.log2
array.logcumsumexp
array.logsumexp
array.max
array.mean

View File

@@ -103,6 +103,7 @@ Operations
log10
log1p
logaddexp
logcumsumexp
logical_not
logical_and
logical_or

View File

@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
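// Identity for the scan: -inf for floating inputs, since logaddexp(-inf, b) == b;
// non-floating inputs fall back to the smallest representable accumulator value.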
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
}
}
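
For reference, the inclusive forward case of the LogAddExp scan dispatched above reduces to the following scalar recurrence (a NumPy sketch for illustration; logcumsumexp_ref is not part of MLX):

import numpy as np

def logcumsumexp_ref(x):
    out = np.empty_like(x)
    acc = -np.inf                   # scan identity: logaddexp(-inf, v) == v
    for i, v in enumerate(x):
        acc = np.logaddexp(acc, v)  # out[i] = log(exp(out[i-1]) + exp(x[i]))
        out[i] = acc
    return out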

View File

@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/metal/kernels/binary_ops.h"
#define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \
@@ -139,6 +141,29 @@ struct CumMin {
}
};
template <typename U>
struct CumLogaddexp {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return LogAddExp{}(a, static_cast<U>(b));
}
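// Inclusive log-add-exp scan across the SIMD group (Hillis-Steele pattern,
// strides 1, 2, 4, 8, 16); lanes shifted in from outside the group are filled
// with the identity value init.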
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_and_fill_up(x, init, i);
x = LogAddExp{}(x, other);
}
return x;
}
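// Exclusive variant: run the inclusive scan, then shift the result up one lane
// so each lane sees the scan of strictly earlier elements (lane 0 gets init).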
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {

View File

@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on

View File

@@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
case Scan::Min:
reduce_type = "min";
break;
case Scan::LogAddExp:
reduce_type = "logaddexp";
break;
}
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel(

View File

@@ -298,7 +298,13 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(Reshape),
SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
SERIALIZE_PRIMITIVE(Round),
SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"),
SERIALIZE_PRIMITIVE(
Scan,
"CumSum",
"CumProd",
"CumMin",
"CumMax",
"CumLogaddexp"),
SERIALIZE_PRIMITIVE(Scatter),
SERIALIZE_PRIMITIVE(Select),
SERIALIZE_PRIMITIVE(Sigmoid),

View File

@@ -3504,6 +3504,28 @@ array cummin(
{a});
}
array logcumsumexp(
const array& a,
int axis,
bool reverse /* = false*/,
bool inclusive /* = true*/,
StreamOrDevice s /* = {}*/) {
int ndim = a.ndim();
if (axis >= ndim || axis < -ndim) {
std::ostringstream msg;
msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
axis = (axis + a.ndim()) % a.ndim();
return array(
a.shape(),
a.dtype(),
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive),
{a});
}
/** Convolution operations */
namespace {

View File

@@ -715,6 +715,14 @@ array topk(const array& a, int k, StreamOrDevice s = {});
/** Returns topk elements of the array along a given axis. */
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
/** Cumulative logsumexp of an array. */
array logcumsumexp(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** The logsumexp of all elements of the array. */
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
inline array logsumexp(const array& a, StreamOrDevice s = {}) {

View File

@@ -3478,6 +3478,45 @@ std::vector<array> Scan::vjp(
if (reduce_type_ == Scan::Sum) {
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
} else if (reduce_type_ == Scan::LogAddExp) {
// Ref:
// https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
auto x = primals[0];
auto grad = cotangents[0];
auto results = outputs[0];
auto zero = zeros({1}, grad.dtype(), stream());
auto grad_min = array(finfo(grad.dtype()).min, grad.dtype());
// Split the incoming gradient into positive and negative part
// in order to take logs. This is required for stable results.
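// For y = logcumsumexp(x) (inclusive forward scan), dy_i/dx_j = exp(x_j - y_i)
// for j <= i, so the VJP is
//   xbar_j = sum_{i >= j} g_i * exp(x_j - y_i)
//          = exp(x_j + logcumsumexp(log(g) - y, reversed)_j),
// which needs g_i > 0 inside the log; hence the positive/negative split below.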
auto log_abs_grad = log(abs(grad, stream()), stream());
auto log_grad_positive =
where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream());
auto log_grad_negative =
where(less(grad, zero, stream()), log_abs_grad, grad_min, stream());
auto output_pos = exp(
add(logcumsumexp(
subtract(log_grad_positive, results, stream()),
axis_,
!reverse_,
inclusive_,
stream()),
x,
stream()));
auto output_neg = exp(
add(logcumsumexp(
subtract(log_grad_negative, results, stream()),
axis_,
!reverse_,
inclusive_,
stream()),
x,
stream()));
return {subtract(output_pos, output_neg, stream())};
} else if (reduce_type_ == Scan::Prod) {
auto in = primals[0];
// Find the location of the first 0 and set it to 1:

View File

@@ -1728,7 +1728,7 @@ class Round : public UnaryPrimitive {
class Scan : public UnaryPrimitive {
public:
enum ReduceType { Max, Min, Sum, Prod };
enum ReduceType { Max, Min, Sum, Prod, LogAddExp };
explicit Scan(
Stream stream,
@@ -1763,6 +1763,9 @@ class Scan : public UnaryPrimitive {
case Max:
os << "Max";
break;
case LogAddExp:
os << "Logaddexp";
break;
}
}
bool is_equivalent(const Primitive& other) const override;

View File

@@ -1202,6 +1202,28 @@ void init_array(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
"See :func:`max`.")
.def(
"logcumsumexp",
[](const mx::array& a,
std::optional<int> axis,
bool reverse,
bool inclusive,
mx::StreamOrDevice s) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"stream"_a = nb::none(),
"See :func:`logcumsumexp`.")
.def(
"logsumexp",
[](const mx::array& a,

View File

@@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) {
Returns:
array: The output array with the corresponding axes reduced.
)pbdoc");
m.def(
"logcumsumexp",
[](const mx::array& a,
std::optional<int> axis,
bool reverse,
bool inclusive,
mx::StreamOrDevice s) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
}
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"stream"_a = nb::none(),
nb::sig(
"def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the cumulative logsumexp of the elements along the given axis.
Args:
a (array): Input array.
axis (int, optional): Optional axis to compute the cumulative logsumexp
over. If unspecified, the cumulative logsumexp of the flattened array is
returned.
reverse (bool): Perform the cumulative logsumexp in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
Returns:
array: The output array.
)pbdoc");
m.def(
"logsumexp",
[](const mx::array& a,

View File

@@ -1508,6 +1508,7 @@ class TestArray(mlx_tests.MLXTestCase):
("prod", 1),
("min", 1),
("max", 1),
("logcumsumexp", 1),
("logsumexp", 1),
("mean", 1),
("var", 1),

View File

@@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase):
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
self.assertTrue(mx.array_equal(y, x[::-1]))
def test_logcumsumexp(self):
npop = np.logaddexp.accumulate
mxop = mx.logcumsumexp
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)
for axis in (0, 1, 2):
c_npy = npop(a_npy, axis=axis)
c_mlx = mxop(a_mlx, axis=axis)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
edge_cases_npy = [
np.float32([-float("inf")] * 8),
np.float32([-float("inf"), 0, -float("inf")]),
np.float32([-float("inf"), float("inf"), -float("inf")]),
]
edge_cases_mlx = [mx.array(a) for a in edge_cases_npy]
for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx):
c_npy = npop(a_npy, axis=0)
c_mlx = mxop(a_mlx, axis=0)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
def test_scans(self):
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)