mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
fix metal kernel linking issue on cuda
This commit is contained in:
parent
283a136c64
commit
c830b5a9f9
@ -3,8 +3,11 @@
|
|||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/fast.h"
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return false;
|
return false;
|
||||||
@ -19,4 +22,21 @@ device_info() {
|
|||||||
"[metal::device_info] Cannot get device info without metal backend");
|
"[metal::device_info] Cannot get device info without metal backend");
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace metal
|
||||||
|
|
||||||
|
namespace fast {
|
||||||
|
|
||||||
|
MetalKernelFunction metal_kernel(
|
||||||
|
const std::string&,
|
||||||
|
const std::vector<std::string>&,
|
||||||
|
const std::vector<std::string>&,
|
||||||
|
const std::string&,
|
||||||
|
const std::string&,
|
||||||
|
bool ensure_row_contiguous,
|
||||||
|
bool atomic_outputs) {
|
||||||
|
throw std::runtime_error("[metal_kernel] No GPU back-end.");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fast
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
#include "mlx/fast.h"
|
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
|
|
||||||
#define NO_GPU_MULTI(func) \
|
#define NO_GPU_MULTI(func) \
|
||||||
@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
|
|||||||
NO_GPU(ScaledDotProductAttention)
|
NO_GPU(ScaledDotProductAttention)
|
||||||
NO_GPU_MULTI(AffineQuantize)
|
NO_GPU_MULTI(AffineQuantize)
|
||||||
NO_GPU_MULTI(CustomKernel)
|
NO_GPU_MULTI(CustomKernel)
|
||||||
|
|
||||||
MetalKernelFunction metal_kernel(
|
|
||||||
const std::string&,
|
|
||||||
const std::vector<std::string>&,
|
|
||||||
const std::vector<std::string>&,
|
|
||||||
const std::string&,
|
|
||||||
const std::string&,
|
|
||||||
bool ensure_row_contiguous,
|
|
||||||
bool atomic_outputs) {
|
|
||||||
throw std::runtime_error("[metal_kernel] No GPU back-end.");
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace fast
|
} // namespace fast
|
||||||
|
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
|
@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase):
|
|||||||
def test_dlx_device_type(self):
|
def test_dlx_device_type(self):
|
||||||
a = mx.array([1, 2, 3])
|
a = mx.array([1, 2, 3])
|
||||||
device_type, device_id = a.__dlpack_device__()
|
device_type, device_id = a.__dlpack_device__()
|
||||||
self.assertIn(device_type, [1, 8])
|
self.assertIn(device_type, [1, 8, 13])
|
||||||
self.assertEqual(device_id, 0)
|
self.assertEqual(device_id, 0)
|
||||||
|
|
||||||
if device_type == 8:
|
if device_type == 8:
|
||||||
|
@ -10,7 +10,7 @@ import mlx_tests
|
|||||||
class TestDefaultDevice(unittest.TestCase):
|
class TestDefaultDevice(unittest.TestCase):
|
||||||
def test_mlx_default_device(self):
|
def test_mlx_default_device(self):
|
||||||
device = mx.default_device()
|
device = mx.default_device()
|
||||||
if mx.metal.is_available():
|
if mx.is_available(mx.gpu):
|
||||||
self.assertEqual(device, mx.Device(mx.gpu))
|
self.assertEqual(device, mx.Device(mx.gpu))
|
||||||
self.assertEqual(str(device), "Device(gpu, 0)")
|
self.assertEqual(str(device), "Device(gpu, 0)")
|
||||||
self.assertEqual(device, mx.gpu)
|
self.assertEqual(device, mx.gpu)
|
||||||
@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(s2.device, mx.default_device())
|
self.assertEqual(s2.device, mx.default_device())
|
||||||
self.assertNotEqual(s1, s2)
|
self.assertNotEqual(s1, s2)
|
||||||
|
|
||||||
if mx.metal.is_available():
|
if mx.is_available(mx.gpu):
|
||||||
s_gpu = mx.default_stream(mx.gpu)
|
s_gpu = mx.default_stream(mx.gpu)
|
||||||
self.assertEqual(s_gpu.device, mx.gpu)
|
self.assertEqual(s_gpu.device, mx.gpu)
|
||||||
else:
|
else:
|
||||||
@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
|
|||||||
s_cpu = mx.new_stream(mx.cpu)
|
s_cpu = mx.new_stream(mx.cpu)
|
||||||
self.assertEqual(s_cpu.device, mx.cpu)
|
self.assertEqual(s_cpu.device, mx.cpu)
|
||||||
|
|
||||||
if mx.metal.is_available():
|
if mx.is_available(mx.gpu):
|
||||||
s_gpu = mx.new_stream(mx.gpu)
|
s_gpu = mx.new_stream(mx.gpu)
|
||||||
self.assertEqual(s_gpu.device, mx.gpu)
|
self.assertEqual(s_gpu.device, mx.gpu)
|
||||||
else:
|
else:
|
||||||
@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
|
a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
|
||||||
|
|
||||||
if mx.metal.is_available():
|
if mx.is_available(mx.gpu):
|
||||||
b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
|
b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
|
||||||
self.assertEqual(a.item(), b.item())
|
self.assertEqual(a.item(), b.item())
|
||||||
s_gpu = mx.new_stream(mx.gpu)
|
s_gpu = mx.new_stream(mx.gpu)
|
||||||
|
@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
|
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
|
||||||
|
|
||||||
|
|
||||||
class TestSchedulers(unittest.TestCase):
|
class TestSchedulers(mlx_tests.MLXTestCase):
|
||||||
def test_decay_lr(self):
|
def test_decay_lr(self):
|
||||||
for optim_class in optimizers_dict.values():
|
for optim_class in optimizers_dict.values():
|
||||||
lr_schedule = opt.step_decay(1e-1, 0.9, 1)
|
lr_schedule = opt.step_decay(1e-1, 0.9, 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user