mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
tests: add complex scans
This commit is contained in:
parent
a19b21ad22
commit
26d0e56e9f
@ -1892,6 +1892,19 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
|
||||
for op in ["cumsum", "cumprod"]:
|
||||
npop = getattr(np, op)
|
||||
mxop = getattr(mx, op)
|
||||
for axis in (None, 0, 1, 2):
|
||||
c_npy = npop(a_npy, axis=axis)
|
||||
c_mlx = mxop(a_mlx, axis=axis)
|
||||
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||
|
||||
# Complex test
|
||||
|
||||
a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j
|
||||
a_mlx = mx.array(a_npy)
|
||||
|
||||
for op in ["cumsum", "cumprod"]:
|
||||
npop = getattr(np, op)
|
||||
mxop = getattr(mx, op)
|
||||
|
Loading…
Reference in New Issue
Block a user