mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Adds device context manager (#679)
This commit is contained in:
@@ -38,6 +38,17 @@ class TestDevice(mlx_tests.MLXTestCase):
|
||||
# Restore device
|
||||
mx.set_default_device(device)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_device_context(self):
|
||||
default = mx.default_device()
|
||||
diff = mx.cpu if default == mx.gpu else mx.gpu
|
||||
self.assertNotEqual(default, diff)
|
||||
with mx.stream(diff):
|
||||
a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
|
||||
mx.eval(a)
|
||||
self.assertEqual(mx.default_device(), diff)
|
||||
self.assertEqual(mx.default_device(), default)
|
||||
|
||||
def test_op_on_device(self):
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
|
Reference in New Issue
Block a user