diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 0508d3362..9b9cc7e21 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1436,6 +1436,10 @@ class TestArray(mlx_tests.MLXTestCase): "sin", "cos", "log1p", + "abs", + "log10", + "log2", + "conj", ("all", 1), ("any", 1), ("transpose", (0, 2, 1)), @@ -1448,6 +1452,16 @@ class TestArray(mlx_tests.MLXTestCase): ("var", 1), ("argmin", 1), ("argmax", 1), + ("cummax", 1), + ("cummin", 1), + ("cumprod", 1), + ("cumsum", 1), + ("diagonal", 0, 0, 1), + ("flatten", 0, -1), + ("moveaxis", 1, 2), + ("round", 2), + ("std", 1, True, 0), + ("swapaxes", 1, 2), ] for op in ops: if isinstance(op, tuple): @@ -1466,6 +1480,11 @@ class TestArray(mlx_tests.MLXTestCase): self.assertEqual(len(y1), len(y2)) self.assertTrue(mx.array_equal(y1[0], y2[0])) self.assertTrue(mx.array_equal(y1[1], y2[1])) + x = mx.array(np.random.rand(10, 10, 1)) + y1 = mx.squeeze(x, axis=2) + y2 = x.squeeze(axis=2) + self.assertEqual(y1.shape, y2.shape) + self.assertTrue(mx.array_equal(y1, y2)) def test_memoryless_copy(self): a_mx = mx.ones((2, 2))