Test to prevent bugs like #1386 (#1391)

* updated test_array for missing ops

* formatting changes
This commit is contained in:
Bhargav Yagnik 2024-09-04 20:24:30 -04:00 committed by GitHub
parent 41c603d48a
commit 11371fe251
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1436,6 +1436,10 @@ class TestArray(mlx_tests.MLXTestCase):
"sin",
"cos",
"log1p",
"abs",
"log10",
"log2",
"conj",
("all", 1),
("any", 1),
("transpose", (0, 2, 1)),
@ -1448,6 +1452,16 @@ class TestArray(mlx_tests.MLXTestCase):
("var", 1),
("argmin", 1),
("argmax", 1),
("cummax", 1),
("cummin", 1),
("cumprod", 1),
("cumsum", 1),
("diagonal", 0, 0, 1),
("flatten", 0, -1),
("moveaxis", 1, 2),
("round", 2),
("std", 1, True, 0),
("swapaxes", 1, 2),
]
for op in ops:
if isinstance(op, tuple):
@ -1466,6 +1480,11 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(len(y1), len(y2))
self.assertTrue(mx.array_equal(y1[0], y2[0]))
self.assertTrue(mx.array_equal(y1[1], y2[1]))
x = mx.array(np.random.rand(10, 10, 1))
y1 = mx.squeeze(x, axis=2)
y2 = x.squeeze(axis=2)
self.assertEqual(y1.shape, y2.shape)
self.assertTrue(mx.array_equal(y1, y2))
def test_memoryless_copy(self):
a_mx = mx.ones((2, 2))