mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
feat: Added dlpack device (#1165)
* feat: Added dlpack device * feat: Added device_id to dlpack device * feat: Added device_id to dlpack device * doc: updated conversion docs * doc: updated numpy.rst dlpack information * doc: updated numpy.rst dlpack information * Update docs/src/usage/numpy.rst * Update docs/src/usage/numpy.rst --------- Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local> Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
fd1c08137b
commit
ab977109db
@ -3,7 +3,11 @@
|
|||||||
Conversion to NumPy and Other Frameworks
|
Conversion to NumPy and Other Frameworks
|
||||||
========================================
|
========================================
|
||||||
|
|
||||||
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
|
MLX array supports conversion between other frameworks with either:
|
||||||
|
|
||||||
|
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
|
||||||
|
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
|
||||||
|
|
||||||
Let's convert an array to NumPy and back.
|
Let's convert an array to NumPy and back.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
@ -10,11 +10,13 @@
|
|||||||
#include <nanobind/stl/variant.h>
|
#include <nanobind/stl/variant.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "python/src/buffer.h"
|
#include "python/src/buffer.h"
|
||||||
#include "python/src/convert.h"
|
#include "python/src/convert.h"
|
||||||
#include "python/src/indexing.h"
|
#include "python/src/indexing.h"
|
||||||
#include "python/src/utils.h"
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
|
#include "mlx/device.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
@ -353,6 +355,17 @@ void init_array(nb::module_& m) {
|
|||||||
new (&arr) array(nd_array_to_mlx(state, std::nullopt));
|
new (&arr) array(nd_array_to_mlx(state, std::nullopt));
|
||||||
})
|
})
|
||||||
.def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); })
|
.def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); })
|
||||||
|
.def(
|
||||||
|
"__dlpack_device__",
|
||||||
|
[](const array& a) {
|
||||||
|
if (metal::is_available()) {
|
||||||
|
// Metal device is available
|
||||||
|
return nb::make_tuple(8, 0);
|
||||||
|
} else {
|
||||||
|
// CPU device
|
||||||
|
return nb::make_tuple(1, 0);
|
||||||
|
}
|
||||||
|
})
|
||||||
.def("__copy__", [](const array& self) { return array(self); })
|
.def("__copy__", [](const array& self) { return array(self); })
|
||||||
.def(
|
.def(
|
||||||
"__deepcopy__",
|
"__deepcopy__",
|
||||||
|
@ -161,6 +161,19 @@ class TestInequality(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(a != b)
|
self.assertTrue(a != b)
|
||||||
self.assertTrue(a != c)
|
self.assertTrue(a != c)
|
||||||
|
|
||||||
|
def test_dlx_device_type(self):
|
||||||
|
a = mx.array([1, 2, 3])
|
||||||
|
device_type, device_id = a.__dlpack_device__()
|
||||||
|
self.assertIn(device_type, [1, 8])
|
||||||
|
self.assertEqual(device_id, 0)
|
||||||
|
|
||||||
|
if device_type == 8:
|
||||||
|
# Additional check if Metal is supposed to be available
|
||||||
|
self.assertTrue(mx.metal.is_available())
|
||||||
|
elif device_type == 1:
|
||||||
|
# Additional check if CPU is the fallback
|
||||||
|
self.assertFalse(mx.metal.is_available())
|
||||||
|
|
||||||
def test_tuple_not_equals_array(self):
|
def test_tuple_not_equals_array(self):
|
||||||
a = mx.array([1, 2, 3])
|
a = mx.array([1, 2, 3])
|
||||||
b = (1, 2, 3)
|
b = (1, 2, 3)
|
||||||
|
Loading…
Reference in New Issue
Block a user