From c830b5a9f90bd25d6341341f94cd33dede6315dd Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 8 Jun 2025 17:33:49 -0700
Subject: [PATCH] fix metal kernel linking issue on cuda

---
 mlx/backend/metal/no_metal.cpp    | 24 ++++++++++++++++++++++--
 mlx/backend/no_gpu/primitives.cpp | 13 -------------
 python/tests/test_array.py        |  2 +-
 python/tests/test_device.py       |  8 ++++----
 python/tests/test_optimizers.py   |  2 +-
 5 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp
index b6142b280..9785e07c2 100644
--- a/mlx/backend/metal/no_metal.cpp
+++ b/mlx/backend/metal/no_metal.cpp
@@ -3,8 +3,11 @@
 #include <stdexcept>
 
 #include "mlx/backend/metal/metal.h"
+#include "mlx/fast.h"
 
-namespace mlx::core::metal {
+namespace mlx::core {
+
+namespace metal {
 
 bool is_available() {
   return false;
@@ -19,4 +22,21 @@ device_info() {
       "[metal::device_info] Cannot get device info without metal backend");
 };
 
-} // namespace mlx::core::metal
+} // namespace metal
+
+namespace fast {
+
+MetalKernelFunction metal_kernel(
+    const std::string&,
+    const std::vector<std::string>&,
+    const std::vector<std::string>&,
+    const std::string&,
+    const std::string&,
+    bool ensure_row_contiguous,
+    bool atomic_outputs) {
+  throw std::runtime_error("[metal_kernel] No GPU back-end.");
+}
+
+} // namespace fast
+
+} // namespace mlx::core
diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp
index 849cbf83e..409aa2c89 100644
--- a/mlx/backend/no_gpu/primitives.cpp
+++ b/mlx/backend/no_gpu/primitives.cpp
@@ -2,7 +2,6 @@
 
 #include "mlx/primitives.h"
 #include "mlx/distributed/primitives.h"
-#include "mlx/fast.h"
 #include "mlx/fast_primitives.h"
 
 #define NO_GPU_MULTI(func) \
@@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
-
-MetalKernelFunction metal_kernel(
-    const std::string&,
-    const std::vector<std::string>&,
-    const std::vector<std::string>&,
-    const std::string&,
-    const std::string&,
-    bool ensure_row_contiguous,
-    bool atomic_outputs) {
-  throw std::runtime_error("[metal_kernel] No GPU back-end.");
-}
-
 } // namespace fast
 
 namespace distributed {
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index e63da17df..c22e0a38f 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase):
     def test_dlx_device_type(self):
         a = mx.array([1, 2, 3])
         device_type, device_id = a.__dlpack_device__()
-        self.assertIn(device_type, [1, 8])
+        self.assertIn(device_type, [1, 8, 13])
         self.assertEqual(device_id, 0)
 
         if device_type == 8:
diff --git a/python/tests/test_device.py b/python/tests/test_device.py
index 53826cad7..6793c98d1 100644
--- a/python/tests/test_device.py
+++ b/python/tests/test_device.py
@@ -10,7 +10,7 @@ import mlx_tests
 class TestDefaultDevice(unittest.TestCase):
     def test_mlx_default_device(self):
         device = mx.default_device()
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             self.assertEqual(device, mx.Device(mx.gpu))
             self.assertEqual(str(device), "Device(gpu, 0)")
             self.assertEqual(device, mx.gpu)
@@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
         self.assertEqual(s2.device, mx.default_device())
         self.assertNotEqual(s1, s2)
 
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             s_gpu = mx.default_stream(mx.gpu)
             self.assertEqual(s_gpu.device, mx.gpu)
         else:
@@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
         s_cpu = mx.new_stream(mx.cpu)
         self.assertEqual(s_cpu.device, mx.cpu)
 
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             s_gpu = mx.new_stream(mx.gpu)
             self.assertEqual(s_gpu.device, mx.gpu)
         else:
@@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):
 
         a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
 
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
             self.assertEqual(a.item(), b.item())
             s_gpu = mx.new_stream(mx.gpu)
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index ebfe97d80..4943fe662 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
 
 
-class TestSchedulers(unittest.TestCase):
+class TestSchedulers(mlx_tests.MLXTestCase):
    def test_decay_lr(self):
        for optim_class in optimizers_dict.values():
            lr_schedule = opt.step_decay(1e-1, 0.9, 1)
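
For reference, a minimal sketch of the backend-agnostic pattern the updated tests use. Every call below appears in the diff itself; the only assumption is reading DLPack device type 13 as kDLCUDAManaged, which matches the DLPack enum but is not stated in the patch:

import mlx.core as mx

# mx.is_available(mx.gpu) is true whenever a GPU backend (Metal or CUDA)
# is present; mx.metal.is_available() stays false on CUDA builds, which
# is why the tests above switch to the device-based check.
if mx.is_available(mx.gpu):
    s_gpu = mx.new_stream(mx.gpu)
    out = mx.add(mx.array(1), mx.array(2), stream=s_gpu)
    print(out.item())  # 3

# Per the updated test, __dlpack_device__ may report 1 (kDLCPU),
# 8 (kDLMetal), or 13 (assumed: kDLCUDAManaged) depending on the backend.
device_type, device_id = mx.array([1, 2, 3]).__dlpack_device__()
print(device_type, device_id)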