From ab977109dbc77713b1289f724b3840c4f9c163e5 Mon Sep 17 00:00:00 2001 From: K Venkat Ramnan Date: Fri, 31 May 2024 12:29:01 -0700 Subject: [PATCH] feat: Added dlpack device (#1165) * feat: Added dlpack device * feat: Added device_id to dlpack device * feat: Added device_id to dlpack device * doc: updated conversion docs * doc: updated numpy.rst dlpack information * doc: updated numpy.rst dlpack information * Update docs/src/usage/numpy.rst * Update docs/src/usage/numpy.rst --------- Co-authored-by: Venkat Ramnan Kalyanakumar Co-authored-by: Awni Hannun --- docs/src/usage/numpy.rst | 6 +++++- python/src/array.cpp | 13 +++++++++++++ python/tests/test_array.py | 13 +++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index 1ed801454..6edb94b8b 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -3,7 +3,11 @@ Conversion to NumPy and Other Frameworks ======================================== -MLX array implements the `Python Buffer Protocol `_. +MLX array supports conversion between other frameworks with either: + +* The `Python Buffer Protocol `_. +* `DLPack `_. + Let's convert an array to NumPy and back. .. code-block:: python diff --git a/python/src/array.cpp b/python/src/array.cpp index 964eb6a5d..15825a926 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -10,11 +10,13 @@ #include #include +#include "mlx/backend/metal/metal.h" #include "python/src/buffer.h" #include "python/src/convert.h" #include "python/src/indexing.h" #include "python/src/utils.h" +#include "mlx/device.h" #include "mlx/ops.h" #include "mlx/transforms.h" #include "mlx/utils.h" @@ -353,6 +355,17 @@ void init_array(nb::module_& m) { new (&arr) array(nd_array_to_mlx(state, std::nullopt)); }) .def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); }) + .def( + "__dlpack_device__", + [](const array& a) { + if (metal::is_available()) { + // Metal device is available + return nb::make_tuple(8, 0); + } else { + // CPU device + return nb::make_tuple(1, 0); + } + }) .def("__copy__", [](const array& self) { return array(self); }) .def( "__deepcopy__", diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 5a87ea88a..67e679c35 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -161,6 +161,19 @@ class TestInequality(mlx_tests.MLXTestCase): self.assertTrue(a != b) self.assertTrue(a != c) + def test_dlx_device_type(self): + a = mx.array([1, 2, 3]) + device_type, device_id = a.__dlpack_device__() + self.assertIn(device_type, [1, 8]) + self.assertEqual(device_id, 0) + + if device_type == 8: + # Additional check if Metal is supposed to be available + self.assertTrue(mx.metal.is_available()) + elif device_type == 1: + # Additional check if CPU is the fallback + self.assertFalse(mx.metal.is_available()) + def test_tuple_not_equals_array(self): a = mx.array([1, 2, 3]) b = (1, 2, 3)