mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
* updated test_array for missing ops * formatting changes
This commit is contained in:
parent
41c603d48a
commit
11371fe251
@ -1436,6 +1436,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
"sin",
|
||||
"cos",
|
||||
"log1p",
|
||||
"abs",
|
||||
"log10",
|
||||
"log2",
|
||||
"conj",
|
||||
("all", 1),
|
||||
("any", 1),
|
||||
("transpose", (0, 2, 1)),
|
||||
@ -1448,6 +1452,16 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
("var", 1),
|
||||
("argmin", 1),
|
||||
("argmax", 1),
|
||||
("cummax", 1),
|
||||
("cummin", 1),
|
||||
("cumprod", 1),
|
||||
("cumsum", 1),
|
||||
("diagonal", 0, 0, 1),
|
||||
("flatten", 0, -1),
|
||||
("moveaxis", 1, 2),
|
||||
("round", 2),
|
||||
("std", 1, True, 0),
|
||||
("swapaxes", 1, 2),
|
||||
]
|
||||
for op in ops:
|
||||
if isinstance(op, tuple):
|
||||
@ -1466,6 +1480,11 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(len(y1), len(y2))
|
||||
self.assertTrue(mx.array_equal(y1[0], y2[0]))
|
||||
self.assertTrue(mx.array_equal(y1[1], y2[1]))
|
||||
x = mx.array(np.random.rand(10, 10, 1))
|
||||
y1 = mx.squeeze(x, axis=2)
|
||||
y2 = x.squeeze(axis=2)
|
||||
self.assertEqual(y1.shape, y2.shape)
|
||||
self.assertTrue(mx.array_equal(y1, y2))
|
||||
|
||||
def test_memoryless_copy(self):
|
||||
a_mx = mx.ones((2, 2))
|
||||
|
Loading…
Reference in New Issue
Block a user