binding + tests

This commit is contained in:
Awni Hannun
2024-12-06 20:05:00 -08:00
parent 2b9c24c517
commit 0c1155faf5
4 changed files with 39 additions and 5 deletions

View File

@@ -809,6 +809,29 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
def test_compile_shapeless_with_reshape(self):
def fun(a):
return mx.reshape(a, (4, 7, 4, 2))
cfun = mx.compile(fun, shapeless=True)
a = mx.zeros((4, 7, 8))
with self.assertRaises(ValueError):
b = cfun(a)
def fun(a):
return mx.dynamic_reshape(a, ("B", "L", 4, 2))
cfun = mx.compile(fun, shapeless=True)
b = cfun(a)
self.assertEqual(b.shape, (4, 7, 4, 2))
a = mx.zeros((4, 9, 8))
b = cfun(a)
self.assertEqual(b.shape, (4, 9, 4, 2))
if __name__ == "__main__":
unittest.main()

View File

@@ -2718,6 +2718,16 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.dynamic_reshape(a, ())
self.assertEqual(a.shape, ())
a = mx.zeros((4, 4, 4))
b = mx.dynamic_reshape(a, ("a", "b", "c"))
self.assertEqual(b.shape, (4, 4, 4))
b = mx.dynamic_reshape(a, ("a*b", "c"))
self.assertEqual(b.shape, (4 * 4, 4))
b = mx.dynamic_reshape(a, ("a*b*c", 1, 1))
self.assertEqual(b.shape, (4 * 4 * 4, 1, 1))
if __name__ == "__main__":
unittest.main()