From 26d0e56e9f90456bd5d9e7e2b9d351047c285c04 Mon Sep 17 00:00:00 2001 From: Yury Popov Date: Sun, 20 Apr 2025 15:52:39 +0300 Subject: [PATCH] tests: add complex scans --- python/tests/test_ops.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 89bff378e..78da34f7f 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1892,6 +1892,19 @@ class TestOps(mlx_tests.MLXTestCase): a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: + npop = getattr(np, op) + mxop = getattr(mx, op) + for axis in (None, 0, 1, 2): + c_npy = npop(a_npy, axis=axis) + c_mlx = mxop(a_mlx, axis=axis) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + + # Complex test + + a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j + a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: npop = getattr(np, op) mxop = getattr(mx, op)