Fix multistream GPU deadlock (#1969)

* fix multistream GPU deadlock

* comments
This commit is contained in:
Awni Hannun
2025-03-20 07:19:47 -07:00
committed by GitHub
parent 95e335db7b
commit 3c164fca8c
6 changed files with 31 additions and 17 deletions

View File

@@ -174,6 +174,17 @@ class TestEval(mlx_tests.MLXTestCase):
post = mx.metal.get_peak_memory()
self.assertEqual(pre, post)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_multistream_deadlock(self):
s1 = mx.default_stream(mx.gpu)
s2 = mx.new_stream(mx.gpu)
x = mx.array(1.0)
x = mx.abs(x, stream=s1)
for _ in range(1000):
x = mx.abs(x, stream=s2)
mx.eval(x)
if __name__ == "__main__":
unittest.main()