mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 03:31:17 +08:00
tests: complex logaddexp / logcumsumexp
This commit is contained in:
parent
26d0e56e9f
commit
ce99b8848b
@ -10,6 +10,47 @@ import mlx_tests
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def np_wrap_between(x, a):
|
||||||
|
"""Wraps `x` between `[-a, a]`."""
|
||||||
|
two_a = 2 * a
|
||||||
|
zero = 0
|
||||||
|
rem = np.remainder(np.add(x, a), two_a)
|
||||||
|
if isinstance(rem, np.ndarray):
|
||||||
|
rem = np.select(rem < zero, np.add(rem, two_a), rem)
|
||||||
|
else:
|
||||||
|
rem = np.add(rem, two_a) if rem < zero else rem
|
||||||
|
return np.subtract(rem, a)
|
||||||
|
|
||||||
|
|
||||||
|
def np_logaddexp(x1: np.ndarray, x2: np.ndarray):
|
||||||
|
amax = np.maximum(x1, x2)
|
||||||
|
if np.issubdtype(x1.dtype, np.floating):
|
||||||
|
delta = np.subtract(x1, x2)
|
||||||
|
if isinstance(delta, np.ndarray):
|
||||||
|
return np.select(
|
||||||
|
np.isnan(delta),
|
||||||
|
np.add(x1, x2),
|
||||||
|
np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
np.add(x1, x2)
|
||||||
|
if np.isnan(delta)
|
||||||
|
else np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta)))))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
delta = np.subtract(np.add(x1, x2), np.multiply(amax, 2))
|
||||||
|
out = np.add(amax, np.log1p(np.exp(delta)))
|
||||||
|
return np.real(out) + 1j * np_wrap_between(np.imag(out), np.pi)
|
||||||
|
|
||||||
|
|
||||||
|
def np_cumlogaddexp(x1: np.ndarray, axis: int = -1):
|
||||||
|
out = x1.copy()
|
||||||
|
for i in range(1, out.shape[axis]):
|
||||||
|
out[i] = np_logaddexp(out[i], out[i - 1])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class TestOps(mlx_tests.MLXTestCase):
|
class TestOps(mlx_tests.MLXTestCase):
|
||||||
def test_full_ones_zeros(self):
|
def test_full_ones_zeros(self):
|
||||||
x = mx.full(2, 3.0)
|
x = mx.full(2, 3.0)
|
||||||
@ -853,6 +894,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertTrue(np.allclose(result, expected))
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
# Complex test
|
||||||
|
|
||||||
|
a = mx.array([0, 1, 2, 9.0]) + 1j
|
||||||
|
b = mx.array([1, 0, 4, 2.5]) + 1j
|
||||||
|
|
||||||
|
result = mx.logaddexp(a, b)
|
||||||
|
expected = np_logaddexp(np.array(a), np.array(b))
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
a = mx.array([float("nan")])
|
a = mx.array([float("nan")])
|
||||||
b = mx.array([0.0])
|
b = mx.array([0.0])
|
||||||
self.assertTrue(math.isnan(mx.logaddexp(a, b).item()))
|
self.assertTrue(math.isnan(mx.logaddexp(a, b).item()))
|
||||||
@ -1888,6 +1939,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
c_mlx = mxop(a_mlx, axis=0)
|
c_mlx = mxop(a_mlx, axis=0)
|
||||||
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||||
|
|
||||||
|
# Complex tests
|
||||||
|
|
||||||
|
a_npy = np.array([1, 2, 3]).astype(np.float32) + 1j
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
c_npy = np_cumlogaddexp(a_npy, axis=-1)
|
||||||
|
c_mlx = mxop(a_mlx, axis=-1)
|
||||||
|
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||||
|
|
||||||
def test_scans(self):
|
def test_scans(self):
|
||||||
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
||||||
a_mlx = mx.array(a_npy)
|
a_mlx = mx.array(a_npy)
|
||||||
|
Loading…
Reference in New Issue
Block a user