From e9e268336b464839034f3f9260ad841380865338 Mon Sep 17 00:00:00 2001
From: Yury Popov
Date: Sun, 13 Apr 2025 11:27:29 +0300
Subject: [PATCH] LogCumSumExp (#2069)

---
 docs/src/python/array.rst            |  1 +
 docs/src/python/ops.rst              |  1 +
 mlx/backend/cpu/scan.cpp             | 11 ++++++++
 mlx/backend/metal/kernels/scan.h     | 25 ++++++++++++++++++
 mlx/backend/metal/kernels/scan.metal |  5 +++-
 mlx/backend/metal/scan.cpp           |  3 +++
 mlx/export.cpp                       |  8 +++++-
 mlx/ops.cpp                          | 22 ++++++++++++++++
 mlx/ops.h                            |  8 ++++++
 mlx/primitives.cpp                   | 39 ++++++++++++++++++++++++++++
 mlx/primitives.h                     |  5 +++-
 python/src/array.cpp                 | 22 ++++++++++++++++
 python/src/ops.cpp                   | 37 ++++++++++++++++++++++++++
 python/tests/test_array.py           |  1 +
 python/tests/test_ops.py             | 24 +++++++++++++++++
 15 files changed, 209 insertions(+), 3 deletions(-)

diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst
index 532bb45c9..7e1c3339d 100644
--- a/docs/src/python/array.rst
+++ b/docs/src/python/array.rst
@@ -38,6 +38,7 @@ Array
    array.log10
    array.log1p
    array.log2
+   array.logcumsumexp
    array.logsumexp
    array.max
    array.mean
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 66c5764ed..55fc1f534 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -103,6 +103,7 @@ Operations
    log10
    log1p
    logaddexp
+   logcumsumexp
    logical_not
    logical_and
    logical_or
diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp
index 1a44ebd39..199dbab35 100644
--- a/mlx/backend/cpu/scan.cpp
+++ b/mlx/backend/cpu/scan.cpp
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
       scan_op(in, out, axis, reverse, inclusive, op, init);
       break;
     }
+    case Scan::LogAddExp: {
+      auto op = [](U a, T b) {
+        return detail::LogAddExp{}(a, static_cast<U>(b));
+      };
+      auto init = (issubdtype(in.dtype(), floating))
+          ? static_cast<U>(-std::numeric_limits<float>::infinity())
+          : std::numeric_limits<U>::min();
+      scan_op(in, out, axis, reverse, inclusive, op, init);
+      break;
+    }
   }
 }
 
diff --git a/mlx/backend/metal/kernels/scan.h b/mlx/backend/metal/kernels/scan.h
index cfa84c04c..cb5147558 100644
--- a/mlx/backend/metal/kernels/scan.h
+++ b/mlx/backend/metal/kernels/scan.h
@@ -2,6 +2,8 @@
 
 #pragma once
 
+#include "mlx/backend/metal/kernels/binary_ops.h"
+
 #define DEFINE_SIMD_SCAN()                                               \
   template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
   T simd_scan(T val) {                                                  \
@@ -139,6 +141,29 @@ struct CumMin {
   }
 };
 
+template <typename U>
+struct CumLogaddexp {
+  static constexpr constant U init = Limits<U>::min;
+
+  template <typename T>
+  U operator()(U a, T b) {
+    return LogAddExp{}(a, static_cast<U>(b));
+  }
+
+  U simd_scan(U x) {
+    for (int i = 1; i <= 16; i *= 2) {
+      U other = simd_shuffle_and_fill_up(x, init, i);
+      x = LogAddExp{}(x, other);
+    }
+    return x;
+  }
+
+  U simd_exclusive_scan(U x) {
+    x = simd_scan(x);
+    return simd_shuffle_and_fill_up(x, init, 1);
+  }
+};
+
 template <typename T, typename U, int N_READS, bool reverse>
 inline void load_unsafe(U values[N_READS], const device T* input) {
   if (reverse) {
diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal
index 6aa36f5a3..8fcd7f61b 100644
--- a/mlx/backend/metal/kernels/scan.metal
+++ b/mlx/backend/metal/kernels/scan.metal
@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
 instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
 instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
 instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
-instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
+instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
+instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
+instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
+instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp
index c7e0087b7..b1800fea9 100644
--- a/mlx/backend/metal/scan.cpp
+++ b/mlx/backend/metal/scan.cpp
@@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     case Scan::Min:
       reduce_type = "min";
       break;
+    case Scan::LogAddExp:
+      reduce_type = "logaddexp";
+      break;
   }
   kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
   auto kernel = get_scan_kernel(
diff --git a/mlx/export.cpp b/mlx/export.cpp
index 8051f786c..51a04a59f 100644
--- a/mlx/export.cpp
+++ b/mlx/export.cpp
@@ -298,7 +298,13 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(Reshape),
       SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
       SERIALIZE_PRIMITIVE(Round),
-      SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"),
+      SERIALIZE_PRIMITIVE(
+          Scan,
+          "CumSum",
+          "CumProd",
+          "CumMin",
+          "CumMax",
+          "CumLogaddexp"),
       SERIALIZE_PRIMITIVE(Scatter),
       SERIALIZE_PRIMITIVE(Select),
       SERIALIZE_PRIMITIVE(Sigmoid),
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 2f3997e7b..6d1116905 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3504,6 +3504,28 @@ array cummin(
       {a});
 }
 
+array logcumsumexp(
+    const array& a,
+    int axis,
+    bool reverse /* = false*/,
+    bool inclusive /* = true*/,
+    StreamOrDevice s /* = {}*/) {
+  int ndim = a.ndim();
+  if (axis >= ndim || axis < -ndim) {
+    std::ostringstream msg;
+    msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with "
+        << a.ndim() << " dimensions.";
dimensions."; + throw std::invalid_argument(msg.str()); + } + axis = (axis + a.ndim()) % a.ndim(); + return array( + a.shape(), + a.dtype(), + std::make_shared( + to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive), + {a}); +} + /** Convolution operations */ namespace { diff --git a/mlx/ops.h b/mlx/ops.h index 02428b974..ce3d4ff44 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -715,6 +715,14 @@ array topk(const array& a, int k, StreamOrDevice s = {}); /** Returns topk elements of the array along a given axis. */ array topk(const array& a, int k, int axis, StreamOrDevice s = {}); +/** Cumulative logsumexp of an array. */ +array logcumsumexp( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + /** The logsumexp of all elements of the array. */ array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {}); inline array logsumexp(const array& a, StreamOrDevice s = {}) { diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 6f9e45313..8328d96da 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3478,6 +3478,45 @@ std::vector Scan::vjp( if (reduce_type_ == Scan::Sum) { return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())}; + } else if (reduce_type_ == Scan::LogAddExp) { + // Ref: + // https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863 + + auto x = primals[0]; + auto grad = cotangents[0]; + auto results = outputs[0]; + + auto zero = zeros({1}, grad.dtype(), stream()); + auto grad_min = array(finfo(grad.dtype()).min, grad.dtype()); + + // Split the incoming gradient into positive and negative part + // in order to take logs. This is required for stable results. + auto log_abs_grad = log(abs(grad, stream()), stream()); + auto log_grad_positive = + where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream()); + auto log_grad_negative = + where(less(grad, zero, stream()), log_abs_grad, grad_min, stream()); + + auto output_pos = exp( + add(logcumsumexp( + subtract(log_grad_positive, results, stream()), + axis_, + !reverse_, + inclusive_, + stream()), + x, + stream())); + auto output_neg = exp( + add(logcumsumexp( + subtract(log_grad_negative, results, stream()), + axis_, + !reverse_, + inclusive_, + stream()), + x, + stream())); + + return {subtract(output_pos, output_neg, stream())}; } else if (reduce_type_ == Scan::Prod) { auto in = primals[0]; // Find the location of the first 0 and set it to 1: diff --git a/mlx/primitives.h b/mlx/primitives.h index c7b2de878..7738b273b 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1728,7 +1728,7 @@ class Round : public UnaryPrimitive { class Scan : public UnaryPrimitive { public: - enum ReduceType { Max, Min, Sum, Prod }; + enum ReduceType { Max, Min, Sum, Prod, LogAddExp }; explicit Scan( Stream stream, @@ -1763,6 +1763,9 @@ class Scan : public UnaryPrimitive { case Max: os << "Max"; break; + case LogAddExp: + os << "Logaddexp"; + break; } } bool is_equivalent(const Primitive& other) const override; diff --git a/python/src/array.cpp b/python/src/array.cpp index e380f2652..467bd0fa5 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1202,6 +1202,28 @@ void init_array(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), "See :func:`max`.") + .def( + "logcumsumexp", + [](const mx::array& a, + std::optional axis, + bool reverse, + bool inclusive, + mx::StreamOrDevice s) { + if (axis) { + return mx::logcumsumexp(a, *axis, reverse, inclusive, 
+            } else {
+              // TODO: Implement that in the C++ API as well. See concatenate
+              // above.
+              return mx::logcumsumexp(
+                  mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
+            }
+          },
+          "axis"_a = nb::none(),
+          nb::kw_only(),
+          "reverse"_a = false,
+          "inclusive"_a = true,
+          "stream"_a = nb::none(),
+          "See :func:`logcumsumexp`.")
       .def(
           "logsumexp",
           [](const mx::array& a,
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 5f078a08d..7f06a4ddf 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) {
       Returns:
           array: The output array with the corresponding axes reduced.
   )pbdoc");
+  m.def(
+      "logcumsumexp",
+      [](const mx::array& a,
+         std::optional<int> axis,
+         bool reverse,
+         bool inclusive,
+         mx::StreamOrDevice s) {
+        if (axis) {
+          return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
+        } else {
+          return mx::logcumsumexp(
+              mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
+        }
+      },
+      nb::arg(),
+      "axis"_a = nb::none(),
+      nb::kw_only(),
+      "reverse"_a = false,
+      "inclusive"_a = true,
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Return the cumulative logsumexp of the elements along the given axis.
+
+        Args:
+            a (array): Input array.
+            axis (int, optional): Optional axis to compute the cumulative
+              logsumexp over. If unspecified, the cumulative logsumexp of the
+              flattened array is returned.
+            reverse (bool): Perform the cumulative logsumexp in reverse.
+            inclusive (bool): The i-th element of the output includes the i-th
+              element of the input.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
   m.def(
       "logsumexp",
       [](const mx::array& a,
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 67dc7c84b..fa5784ea9 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -1508,6 +1508,7 @@ class TestArray(mlx_tests.MLXTestCase):
             ("prod", 1),
             ("min", 1),
             ("max", 1),
+            ("logcumsumexp", 1),
             ("logsumexp", 1),
             ("mean", 1),
             ("var", 1),
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index a71d2c253..4fcb31f18 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase):
         y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
         self.assertTrue(mx.array_equal(y, x[::-1]))
 
+    def test_logcumsumexp(self):
+        npop = np.logaddexp.accumulate
+        mxop = mx.logcumsumexp
+
+        a_npy = np.random.randn(32, 32, 32).astype(np.float32)
+        a_mlx = mx.array(a_npy)
+
+        for axis in (0, 1, 2):
+            c_npy = npop(a_npy, axis=axis)
+            c_mlx = mxop(a_mlx, axis=axis)
+            self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
+
+        edge_cases_npy = [
+            np.float32([-float("inf")] * 8),
+            np.float32([-float("inf"), 0, -float("inf")]),
+            np.float32([-float("inf"), float("inf"), -float("inf")]),
+        ]
+        edge_cases_mlx = [mx.array(a) for a in edge_cases_npy]
+
+        for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx):
+            c_npy = npop(a_npy, axis=0)
+            c_mlx = mxop(a_mlx, axis=0)
+            self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
+
     def test_scans(self):
         a_npy = np.random.randn(32, 32, 32).astype(np.float32)
         a_mlx = mx.array(a_npy)
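
Illustrative usage (not part of the patch): a minimal sketch of what the new op computes, assuming the Python bindings added above are built and exposed as mlx.core.logcumsumexp. It mirrors the reference check in test_logcumsumexp: the result should match NumPy's accumulated logaddexp, whereas a naive log(cumsum(exp(x))) can overflow or underflow for large-magnitude inputs.

    import numpy as np
    import mlx.core as mx

    x = mx.array(np.random.randn(8).astype(np.float32))

    # Cumulative logsumexp along axis 0: out[i] = log(sum_{j <= i} exp(x[j])).
    out = mx.logcumsumexp(x, axis=0)

    # Reference: NumPy's running logaddexp reduction gives the same values.
    ref = np.logaddexp.accumulate(np.array(x), axis=0)
    assert np.allclose(np.array(out), ref, rtol=1e-3, atol=1e-3)

    # reverse/inclusive follow the other scan ops (cumsum, cummax, ...):
    # reverse=True scans from the end of the axis, inclusive=False excludes
    # the current element from each partial reduction.
    out_rev = mx.logcumsumexp(x, axis=0, reverse=True, inclusive=True)

The VJP added in mlx/primitives.cpp follows the TensorFlow-referenced trick cited in its comment: the incoming cotangent is split by sign so each part can be passed through log safely, pushed through logcumsumexp in the opposite scan direction, and the two exponentiated parts are recombined by subtraction.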