tests: add complex scans

This commit is contained in:
Yury Popov 2025-04-20 15:52:39 +03:00
parent a19b21ad22
commit 26d0e56e9f
No known key found for this signature in database
GPG Key ID: 76DE18AD6634F257

View File

@ -1892,6 +1892,19 @@ class TestOps(mlx_tests.MLXTestCase):
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)
for op in ["cumsum", "cumprod"]:
npop = getattr(np, op)
mxop = getattr(mx, op)
for axis in (None, 0, 1, 2):
c_npy = npop(a_npy, axis=axis)
c_mlx = mxop(a_mlx, axis=axis)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
# Complex test
a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j
a_mlx = mx.array(a_npy)
for op in ["cumsum", "cumprod"]:
npop = getattr(np, op)
mxop = getattr(mx, op)