mirror of https://github.com/ml-explore/mlx.git
start cuda circle config (#2256)
* rebase
* fix metal kernel linking issue on cuda
* start cuda circle config
@@ -17,10 +17,7 @@
 #include "python/src/indexing.h"
 #include "python/src/utils.h"
 
-#include "mlx/device.h"
-#include "mlx/ops.h"
-#include "mlx/transforms.h"
-#include "mlx/utils.h"
+#include "mlx/mlx.h"
 
 namespace mx = mlx::core;
 namespace nb = nanobind;
@@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
       .def(
           "__dlpack_device__",
           [](const mx::array& a) {
             // See
             // https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
             if (mx::metal::is_available()) {
               // Metal device is available
               return nb::make_tuple(8, 0);
+            } else if (mx::cu::is_available()) {
+              return nb::make_tuple(13, 0);
             } else {
               // CPU device
               return nb::make_tuple(1, 0);
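The new branch reports device type 13 for CUDA. Per the DLDeviceType enum in the dlpack.h header linked above, 1 is kDLCPU, 8 is kDLMetal, and 13 is kDLCUDAManaged. A minimal sketch of querying this from Python (the names dict is purely illustrative):

import mlx.core as mx

a = mx.array([1, 2, 3])
device_type, device_id = a.__dlpack_device__()

# DLDeviceType codes from dlpack.h: kDLCPU = 1, kDLMetal = 8,
# kDLCUDAManaged = 13 (what the new CUDA branch returns)
names = {1: "cpu", 8: "metal", 13: "cuda (managed)"}
print(names.get(device_type, "unknown"), device_id)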
@@ -58,4 +58,9 @@ void init_device(nb::module_& m) {
       &mx::set_default_device,
       "device"_a,
       R"pbdoc(Set the default device.)pbdoc");
+  m.def(
+      "is_available",
+      &mx::is_available,
+      "device"_a,
+      R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
 }
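This exposes mx.is_available as a backend-agnostic availability check; the test updates below use it to replace the Metal-specific mx.metal.is_available(). A short usage sketch, using only calls that appear in this diff:

import mlx.core as mx

# True on any GPU backend (Metal or CUDA), unlike mx.metal.is_available()
if mx.is_available(mx.gpu):
    mx.set_default_device(mx.gpu)
else:
    mx.set_default_device(mx.cpu)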
@@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase):
     def test_dlx_device_type(self):
         a = mx.array([1, 2, 3])
         device_type, device_id = a.__dlpack_device__()
-        self.assertIn(device_type, [1, 8])
+        self.assertIn(device_type, [1, 8, 13])
         self.assertEqual(device_id, 0)
 
         if device_type == 8:
@@ -10,7 +10,7 @@ import mlx_tests
 class TestDefaultDevice(unittest.TestCase):
     def test_mlx_default_device(self):
         device = mx.default_device()
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             self.assertEqual(device, mx.Device(mx.gpu))
             self.assertEqual(str(device), "Device(gpu, 0)")
             self.assertEqual(device, mx.gpu)
@@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
         self.assertEqual(s2.device, mx.default_device())
         self.assertNotEqual(s1, s2)
 
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             s_gpu = mx.default_stream(mx.gpu)
             self.assertEqual(s_gpu.device, mx.gpu)
         else:
@@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
         s_cpu = mx.new_stream(mx.cpu)
         self.assertEqual(s_cpu.device, mx.cpu)
 
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             s_gpu = mx.new_stream(mx.gpu)
             self.assertEqual(s_gpu.device, mx.gpu)
         else:
@@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):
 
         a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
 
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
             self.assertEqual(a.item(), b.item())
             s_gpu = mx.new_stream(mx.gpu)
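The three stream hunks above make the same mechanical substitution, gating GPU stream tests on mx.is_available(mx.gpu) so they run under both Metal and CUDA builds. The same pattern outside the test harness, as a minimal sketch:

import mlx.core as mx

x, y = mx.array(1.0), mx.array(2.0)

# Run on a stream of whatever GPU backend is compiled in, else stay on CPU
if mx.is_available(mx.gpu):
    out = mx.add(x, y, stream=mx.new_stream(mx.gpu))
else:
    out = mx.add(x, y, stream=mx.default_stream(mx.cpu))
print(out.item())  # 3.0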
@@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
 
 
-class TestSchedulers(unittest.TestCase):
+class TestSchedulers(mlx_tests.MLXTestCase):
     def test_decay_lr(self):
         for optim_class in optimizers_dict.values():
             lr_schedule = opt.step_decay(1e-1, 0.9, 1)
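Only the base class changes here, presumably so the scheduler tests pick up the shared setup in mlx_tests.MLXTestCase. For context, a hedged sketch of what the step-decay schedule in the surrounding lines computes, assuming the documented step_decay(init, decay_rate, step_size) signature:

import mlx.optimizers as opt

# lr(step) = init * decay_rate ** (step // step_size)
lr_schedule = opt.step_decay(1e-1, 0.9, 1)

# Schedules are callables on the optimizer's step count and can be
# passed to any optimizer in place of a fixed learning rate
optimizer = opt.SGD(learning_rate=lr_schedule)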