From 1d2c9d6a07f76b82b3c80c6a6682ba7f7161865d Mon Sep 17 00:00:00 2001 From: Yury Popov Date: Wed, 23 Apr 2025 04:56:28 +0300 Subject: [PATCH] Complex scan (#2094) --- mlx/backend/cpu/binary.cpp | 5 +- mlx/backend/cpu/scan.cpp | 3 +- mlx/backend/cpu/simd/base_simd.h | 23 +++++++- mlx/backend/metal/kernels/binary.metal | 1 + mlx/backend/metal/kernels/binary_ops.h | 18 ++++++ mlx/backend/metal/kernels/scan.metal | 3 +- mlx/backend/metal/kernels/unary.metal | 1 + mlx/backend/metal/kernels/utils.h | 17 ++++++ python/tests/test_ops.py | 79 ++++++++++++++++++++++++++ 9 files changed, 146 insertions(+), 4 deletions(-) diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp index dbdab6a06..35aa2a3e0 100644 --- a/mlx/backend/cpu/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -172,9 +172,12 @@ void binary_float( case bfloat16: binary_op(a, b, out, bopt); break; + case complex64: + binary_op(a, b, out, bopt); + break; default: throw std::runtime_error( - "[binary_float] Only supports non-complex floating point types."); + "[binary_float] Only supports floating point types."); } }); } diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index 199dbab35..33addd161 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector& inputs, array& out) { reduce_type_, in, out, axis_, reverse_, inclusive_); break; case complex64: - throw std::runtime_error("Scan ops do not support complex types yet"); + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); break; } }); diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h index 7e82a4d56..17cd35b9a 100644 --- a/mlx/backend/cpu/simd/base_simd.h +++ b/mlx/backend/cpu/simd/base_simd.h @@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log10, std::log10) -DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tanh, std::tanh) +template +Simd log1p(Simd in) { + if constexpr (is_complex) { + auto x = in.value.real(); + auto y = in.value.imag(); + auto zabs = std::abs(in.value); + auto theta = std::atan2(y, x + 1); + if (zabs < 0.5) { + auto r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return Simd{T{x, theta}}; + } + return Simd{T{((typeof(x))(0.5)) * std::log1p(r), theta}}; + } else { + auto z0 = std::hypot(x + 1, y); + return Simd{T{std::log(z0), theta}}; + } + } else { + return Simd{std::log1p(in.value)}; + } +} + template Simd log2(Simd in) { if constexpr (is_complex) { diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 3ef8e6269..1d555fefa 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less) instantiate_binary_types_bool(LessEqual) instantiate_binary_types_bool(NotEqual) instantiate_binary_float(LogAddExp) +instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t) instantiate_binary_types(Maximum) instantiate_binary_types(Minimum) instantiate_binary_types(Multiply) diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h index 8f961c2cf..4aaf2b4da 100644 --- a/mlx/backend/metal/kernels/binary_ops.h +++ b/mlx/backend/metal/kernels/binary_ops.h @@ -130,6 +130,24 @@ struct LogAddExp { ? maxval : (maxval + log1p(metal::exp(minval - maxval))); }; + + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) || + metal::isnan(y.imag)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr float inf = metal::numeric_limits::infinity(); + complex64_t maxval = x > y ? x : y; + complex64_t minval = x < y ? x : y; + if (minval.real == -inf || maxval.real == inf) + return maxval; + float m = metal::exp(minval.real - maxval.real); + complex64_t dexp{ + m * metal::cos(minval.imag - maxval.imag), + m * metal::sin(minval.imag - maxval.imag), + }; + return maxval + log1p(dexp); + } }; struct Maximum { diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal index 8fcd7f61b..f38f8757e 100644 --- a/mlx/backend/metal/kernels/scan.metal +++ b/mlx/backend/metal/kernels/scan.metal @@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4) instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4) -instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on +instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) +instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index d34c5a7ec..afced7eb7 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -77,6 +77,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t) instantiate_unary_all_same(Log, complex64, complex64_t) +instantiate_unary_all_same(Log1p, complex64, complex64_t) instantiate_unary_all_same(Log2, complex64, complex64_t) instantiate_unary_all_same(Log10, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index b31cd20d6..1170d5576 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -328,6 +328,23 @@ inline bfloat16_t log1p(bfloat16_t x) { return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); } +inline complex64_t log1p(complex64_t in) { + float x = in.real; + float y = in.imag; + float zabs = metal::precise::sqrt(x * x + y * y); + float theta = metal::atan2(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1p(r), theta}; + } else { + auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y); + return {metal::log(z0), theta}; + } +} + /////////////////////////////////////////////////////////////////////////////// // SIMD shuffle ops /////////////////////////////////////////////////////////////////////////////// diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 31ea79345..d0e52eab2 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -10,6 +10,47 @@ import mlx_tests import numpy as np +def np_wrap_between(x, a): + """Wraps `x` between `[-a, a]`.""" + two_a = 2 * a + zero = 0 + rem = np.remainder(np.add(x, a), two_a) + if isinstance(rem, np.ndarray): + rem = np.select(rem < zero, np.add(rem, two_a), rem) + else: + rem = np.add(rem, two_a) if rem < zero else rem + return np.subtract(rem, a) + + +def np_logaddexp(x1: np.ndarray, x2: np.ndarray): + amax = np.maximum(x1, x2) + if np.issubdtype(x1.dtype, np.floating): + delta = np.subtract(x1, x2) + if isinstance(delta, np.ndarray): + return np.select( + np.isnan(delta), + np.add(x1, x2), + np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))), + ) + else: + return ( + np.add(x1, x2) + if np.isnan(delta) + else np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))) + ) + else: + delta = np.subtract(np.add(x1, x2), np.multiply(amax, 2)) + out = np.add(amax, np.log1p(np.exp(delta))) + return np.real(out) + 1j * np_wrap_between(np.imag(out), np.pi) + + +def np_cumlogaddexp(x1: np.ndarray, axis: int = -1): + out = x1.copy() + for i in range(1, out.shape[axis]): + out[i] = np_logaddexp(out[i], out[i - 1]) + return out + + class TestOps(mlx_tests.MLXTestCase): def test_full_ones_zeros(self): x = mx.full(2, 3.0) @@ -853,6 +894,16 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + # Complex test + + a = mx.array([0, 1, 2, 9.0]) + 1j + b = mx.array([1, 0, 4, 2.5]) + 1j + + result = mx.logaddexp(a, b) + expected = np_logaddexp(np.array(a), np.array(b)) + + self.assertTrue(np.allclose(result, expected)) + a = mx.array([float("nan")]) b = mx.array([0.0]) self.assertTrue(math.isnan(mx.logaddexp(a, b).item())) @@ -977,6 +1028,13 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + # Complex test + a = mx.array([1, 0.5, 10, 100]) + 1j + result = mx.log1p(a) + expected = np.log1p(a, dtype=np.complex64) + + self.assertTrue(np.allclose(result, expected)) + def test_sigmoid(self): a = mx.array([0.0, 1.0, -1.0, 5.0, -5.0]) result = mx.sigmoid(a) @@ -1881,10 +1939,31 @@ class TestOps(mlx_tests.MLXTestCase): c_mlx = mxop(a_mlx, axis=0) self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + # Complex tests + + a_npy = np.array([1, 2, 3]).astype(np.float32) + 1j + a_mlx = mx.array(a_npy) + c_npy = np_cumlogaddexp(a_npy, axis=-1) + c_mlx = mxop(a_mlx, axis=-1) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + def test_scans(self): a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: + npop = getattr(np, op) + mxop = getattr(mx, op) + for axis in (None, 0, 1, 2): + c_npy = npop(a_npy, axis=axis) + c_mlx = mxop(a_mlx, axis=axis) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + + # Complex test + + a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j + a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: npop = getattr(np, op) mxop = getattr(mx, op)