Async eval (#972)

This commit is contained in:
Awni Hannun
2024-04-09 18:34:00 -07:00
committed by GitHub
parent fffe072028
commit 99abb9eff4
5 changed files with 70 additions and 3 deletions

View File

@@ -32,6 +32,18 @@ class TestEval(mlx_tests.MLXTestCase):
mx.eval(state)
self.assertEqual(x.item(), 3)
def test_async_eval(self):
x = mx.array(1) + mx.array(1) + mx.array(1)
sync = mx.async_eval(x)
sync.wait()
self.assertEqual(x.item(), 3)
# It should be safe to call eval on the array which has been async
# eval'ed
x = mx.array(1) + mx.array(1) + mx.array(1)
sync = mx.async_eval(x)
self.assertEqual(x.item(), 3)
if __name__ == "__main__":
unittest.main()