LogCumSumExp (#2069)

This commit is contained in:
Yury Popov
2025-04-13 11:27:29 +03:00
committed by GitHub
parent 7275ac7523
commit e9e268336b
15 changed files with 209 additions and 3 deletions

View File

@@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase):
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
self.assertTrue(mx.array_equal(y, x[::-1]))
def test_logcumsumexp(self):
npop = np.logaddexp.accumulate
mxop = mx.logcumsumexp
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)
for axis in (0, 1, 2):
c_npy = npop(a_npy, axis=axis)
c_mlx = mxop(a_mlx, axis=axis)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
edge_cases_npy = [
np.float32([-float("inf")] * 8),
np.float32([-float("inf"), 0, -float("inf")]),
np.float32([-float("inf"), float("inf"), -float("inf")]),
]
edge_cases_mlx = [mx.array(a) for a in edge_cases_npy]
for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx):
c_npy = npop(a_npy, axis=0)
c_mlx = mxop(a_mlx, axis=0)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
def test_scans(self):
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)