Adds device context manager (#679)

This commit is contained in:
Diogo
2024-02-14 17:14:58 -05:00
committed by GitHub
parent ccf1645995
commit 35431a4ac8
15 changed files with 230 additions and 77 deletions

View File

@@ -38,6 +38,17 @@ class TestDevice(mlx_tests.MLXTestCase):
# Restore device
mx.set_default_device(device)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_device_context(self):
default = mx.default_device()
diff = mx.cpu if default == mx.gpu else mx.gpu
self.assertNotEqual(default, diff)
with mx.stream(diff):
a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
mx.eval(a)
self.assertEqual(mx.default_device(), diff)
self.assertEqual(mx.default_device(), default)
def test_op_on_device(self):
x = mx.array(1.0)
y = mx.array(1.0)