mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
tests: add complex scans
This commit is contained in:
parent
a19b21ad22
commit
26d0e56e9f
@ -1892,6 +1892,19 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
||||||
a_mlx = mx.array(a_npy)
|
a_mlx = mx.array(a_npy)
|
||||||
|
|
||||||
|
for op in ["cumsum", "cumprod"]:
|
||||||
|
npop = getattr(np, op)
|
||||||
|
mxop = getattr(mx, op)
|
||||||
|
for axis in (None, 0, 1, 2):
|
||||||
|
c_npy = npop(a_npy, axis=axis)
|
||||||
|
c_mlx = mxop(a_mlx, axis=axis)
|
||||||
|
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||||
|
|
||||||
|
# Complex test
|
||||||
|
|
||||||
|
a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
|
||||||
for op in ["cumsum", "cumprod"]:
|
for op in ["cumsum", "cumprod"]:
|
||||||
npop = getattr(np, op)
|
npop = getattr(np, op)
|
||||||
mxop = getattr(mx, op)
|
mxop = getattr(mx, op)
|
||||||
|
Loading…
Reference in New Issue
Block a user