add python testing for cuda with ability to skip list of tests (#2295)

This commit is contained in:
Awni Hannun
2025-06-15 10:56:48 -07:00
committed by GitHub
parent 580776559b
commit 4fda5fbdf9
36 changed files with 220 additions and 35 deletions

View File

@@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase):
# Restore device
mx.set_default_device(device)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_device_context(self):
default = mx.default_device()
diff = mx.cpu if default == mx.gpu else mx.gpu
@@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()