start cuda circle config (#2256)

* rebase

* fix metal kernel linking issue on cuda

* start cuda circle config
This commit is contained in:
Awni Hannun
2025-06-10 21:19:47 -07:00
committed by GitHub
parent 8590c0941e
commit c35f4d089a
14 changed files with 101 additions and 26 deletions

View File

@@ -10,7 +10,7 @@ import mlx_tests
class TestDefaultDevice(unittest.TestCase):
def test_mlx_default_device(self):
device = mx.default_device()
if mx.metal.is_available():
if mx.is_available(mx.gpu):
self.assertEqual(device, mx.Device(mx.gpu))
self.assertEqual(str(device), "Device(gpu, 0)")
self.assertEqual(device, mx.gpu)
@@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
self.assertEqual(s2.device, mx.default_device())
self.assertNotEqual(s1, s2)
if mx.metal.is_available():
if mx.is_available(mx.gpu):
s_gpu = mx.default_stream(mx.gpu)
self.assertEqual(s_gpu.device, mx.gpu)
else:
@@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
s_cpu = mx.new_stream(mx.cpu)
self.assertEqual(s_cpu.device, mx.cpu)
if mx.metal.is_available():
if mx.is_available(mx.gpu):
s_gpu = mx.new_stream(mx.gpu)
self.assertEqual(s_gpu.device, mx.gpu)
else:
@@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):
a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
if mx.metal.is_available():
if mx.is_available(mx.gpu):
b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
self.assertEqual(a.item(), b.item())
s_gpu = mx.new_stream(mx.gpu)