mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
LogCumSumExp (#2069)
This commit is contained in:
@@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
|
||||
self.assertTrue(mx.array_equal(y, x[::-1]))
|
||||
|
||||
def test_logcumsumexp(self):
|
||||
npop = np.logaddexp.accumulate
|
||||
mxop = mx.logcumsumexp
|
||||
|
||||
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
|
||||
for axis in (0, 1, 2):
|
||||
c_npy = npop(a_npy, axis=axis)
|
||||
c_mlx = mxop(a_mlx, axis=axis)
|
||||
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||
|
||||
edge_cases_npy = [
|
||||
np.float32([-float("inf")] * 8),
|
||||
np.float32([-float("inf"), 0, -float("inf")]),
|
||||
np.float32([-float("inf"), float("inf"), -float("inf")]),
|
||||
]
|
||||
edge_cases_mlx = [mx.array(a) for a in edge_cases_npy]
|
||||
|
||||
for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx):
|
||||
c_npy = npop(a_npy, axis=0)
|
||||
c_mlx = mxop(a_mlx, axis=0)
|
||||
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||
|
||||
def test_scans(self):
|
||||
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
|
Reference in New Issue
Block a user