try dynamic reshape

This commit is contained in:
Awni Hannun
2024-12-06 12:09:08 -08:00
parent 40c62c1321
commit ee59d50293
7 changed files with 164 additions and 0 deletions

View File

@@ -2713,6 +2713,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.imag(z).dtype, mx.float32)
self.assertTrue(mx.array_equal(mx.imag(z), y))
def test_dynamic_reshape(self):
a = mx.array(1)[None, None]
a = mx.dynamic_reshape(a, ())
self.assertEqual(a.shape, ())
if __name__ == "__main__":
unittest.main()