Make the GPU device more thread safe (#1478)

* gpu stream safety

* comment

* fix
This commit is contained in:
Awni Hannun
2024-10-12 17:49:15 -07:00
committed by GitHub
parent c21331d47f
commit bf6ec92216
8 changed files with 208 additions and 298 deletions

View File

@@ -122,6 +122,21 @@ class TestEval(mlx_tests.MLXTestCase):
out = mx.vjp(fn, (x,), (y,))
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
def test_async_eval_with_multiple_streams(self):
x = mx.array([1.0])
y = mx.array([1.0])
a = mx.array([1.0])
b = mx.array([1.0])
d = mx.default_device()
s2 = mx.new_stream(d)
for _ in range(50):
for _ in range(20):
x = x + y
mx.async_eval(x)
mx.eval(a + b)
if __name__ == "__main__":
unittest.main()