fix metal kernel linking issue on cuda

This commit is contained in:
Awni Hannun 2025-06-08 17:33:49 -07:00
parent 283a136c64
commit c830b5a9f9
5 changed files with 28 additions and 21 deletions

View File

@ -3,8 +3,11 @@
#include <stdexcept>
#include "mlx/backend/metal/metal.h"
#include "mlx/fast.h"
namespace mlx::core::metal {
namespace mlx::core {
namespace metal {
bool is_available() {
return false;
@ -19,4 +22,21 @@ device_info() {
"[metal::device_info] Cannot get device info without metal backend");
};
} // namespace mlx::core::metal
} // namespace metal
namespace fast {
MetalKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
}
} // namespace fast
} // namespace mlx::core

View File

@ -2,7 +2,6 @@
#include "mlx/primitives.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#define NO_GPU_MULTI(func) \
@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
MetalKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
}
} // namespace fast
namespace distributed {

View File

@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase):
def test_dlx_device_type(self):
a = mx.array([1, 2, 3])
device_type, device_id = a.__dlpack_device__()
self.assertIn(device_type, [1, 8])
self.assertIn(device_type, [1, 8, 13])
self.assertEqual(device_id, 0)
if device_type == 8:

View File

@ -10,7 +10,7 @@ import mlx_tests
class TestDefaultDevice(unittest.TestCase):
def test_mlx_default_device(self):
device = mx.default_device()
if mx.metal.is_available():
if mx.is_available(mx.gpu):
self.assertEqual(device, mx.Device(mx.gpu))
self.assertEqual(str(device), "Device(gpu, 0)")
self.assertEqual(device, mx.gpu)
@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
self.assertEqual(s2.device, mx.default_device())
self.assertNotEqual(s1, s2)
if mx.metal.is_available():
if mx.is_available(mx.gpu):
s_gpu = mx.default_stream(mx.gpu)
self.assertEqual(s_gpu.device, mx.gpu)
else:
@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
s_cpu = mx.new_stream(mx.cpu)
self.assertEqual(s_cpu.device, mx.cpu)
if mx.metal.is_available():
if mx.is_available(mx.gpu):
s_gpu = mx.new_stream(mx.gpu)
self.assertEqual(s_gpu.device, mx.gpu)
else:
@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):
a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
if mx.metal.is_available():
if mx.is_available(mx.gpu):
b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
self.assertEqual(a.item(), b.item())
s_gpu = mx.new_stream(mx.gpu)

View File

@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
class TestSchedulers(unittest.TestCase):
class TestSchedulers(mlx_tests.MLXTestCase):
def test_decay_lr(self):
for optim_class in optimizers_dict.values():
lr_schedule = opt.step_decay(1e-1, 0.9, 1)