basic python tests

This commit is contained in:
Awni Hannun
2024-01-15 06:08:18 -08:00
parent 9739c72781
commit 0005cfe053
3 changed files with 79 additions and 21 deletions

View File

@@ -15,7 +15,12 @@ class TestCompile(mlx_tests.MLXTestCase):
compiled_fn = mx.compile(fun)
x = mx.array(1.0)
y = mx.array(1.0)
# out = compiled_fn(x, y)
out = compiled_fn(x, y)
self.assertEqual(out.item(), 2.0)
# Try again
out = compiled_fn(x, y)
self.assertEqual(out.item(), 2.0)
if __name__ == "__main__":