From ce99b8848be6d4f7f86091c4dbbbcd751b527bcb Mon Sep 17 00:00:00 2001 From: Yury Popov Date: Sun, 20 Apr 2025 16:43:03 +0300 Subject: [PATCH] tests: complex logaddexp / logcumsumexp --- python/tests/test_ops.py | 59 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 78da34f7f..13b255f8d 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -10,6 +10,47 @@ import mlx_tests import numpy as np +def np_wrap_between(x, a): + """Wraps `x` between `[-a, a]`.""" + two_a = 2 * a + zero = 0 + rem = np.remainder(np.add(x, a), two_a) + if isinstance(rem, np.ndarray): + rem = np.select(rem < zero, np.add(rem, two_a), rem) + else: + rem = np.add(rem, two_a) if rem < zero else rem + return np.subtract(rem, a) + + +def np_logaddexp(x1: np.ndarray, x2: np.ndarray): + amax = np.maximum(x1, x2) + if np.issubdtype(x1.dtype, np.floating): + delta = np.subtract(x1, x2) + if isinstance(delta, np.ndarray): + return np.select( + np.isnan(delta), + np.add(x1, x2), + np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))), + ) + else: + return ( + np.add(x1, x2) + if np.isnan(delta) + else np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))) + ) + else: + delta = np.subtract(np.add(x1, x2), np.multiply(amax, 2)) + out = np.add(amax, np.log1p(np.exp(delta))) + return np.real(out) + 1j * np_wrap_between(np.imag(out), np.pi) + + +def np_cumlogaddexp(x1: np.ndarray, axis: int = -1): + out = x1.copy() + for i in range(1, out.shape[axis]): + out[i] = np_logaddexp(out[i], out[i - 1]) + return out + + class TestOps(mlx_tests.MLXTestCase): def test_full_ones_zeros(self): x = mx.full(2, 3.0) @@ -853,6 +894,16 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + # Complex test + + a = mx.array([0, 1, 2, 9.0]) + 1j + b = mx.array([1, 0, 4, 2.5]) + 1j + + result = mx.logaddexp(a, b) + expected = np_logaddexp(np.array(a), np.array(b)) + + self.assertTrue(np.allclose(result, expected)) + a = mx.array([float("nan")]) b = mx.array([0.0]) self.assertTrue(math.isnan(mx.logaddexp(a, b).item())) @@ -1888,6 +1939,14 @@ class TestOps(mlx_tests.MLXTestCase): c_mlx = mxop(a_mlx, axis=0) self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + # Complex tests + + a_npy = np.array([1, 2, 3]).astype(np.float32) + 1j + a_mlx = mx.array(a_npy) + c_npy = np_cumlogaddexp(a_npy, axis=-1) + c_mlx = mxop(a_mlx, axis=-1) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + def test_scans(self): a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_mlx = mx.array(a_npy)