mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Make the GPU device more thread safe (#1478)
* gpu stream safety * comment * fix
This commit is contained in:
@@ -122,6 +122,21 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
|
||||
|
||||
def test_async_eval_with_multiple_streams(self):
|
||||
x = mx.array([1.0])
|
||||
y = mx.array([1.0])
|
||||
a = mx.array([1.0])
|
||||
b = mx.array([1.0])
|
||||
|
||||
d = mx.default_device()
|
||||
s2 = mx.new_stream(d)
|
||||
|
||||
for _ in range(50):
|
||||
for _ in range(20):
|
||||
x = x + y
|
||||
mx.async_eval(x)
|
||||
mx.eval(a + b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user