From 81dd33af6696fa2f68572785f0c38046ed7600a6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 16 May 2024 16:11:37 -0700 Subject: [PATCH] allow conversion to dlpack (#1120) --- python/src/array.cpp | 9 ++------ python/src/convert.cpp | 42 +++++++++++++++++++++----------------- python/src/convert.h | 2 ++ python/tests/test_array.py | 14 +++++++++++++ 4 files changed, 41 insertions(+), 26 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 943ce2a5a..8adf7d3fd 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -669,19 +669,14 @@ void init_array(nb::module_& m) { return a.shape(0); }) .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) - .def( - "__getstate__", - [](const array& a) { - if (a.dtype() == bfloat16) { - } - return mlx_to_np_array(a); - }) + .def("__getstate__", &mlx_to_np_array) .def( "__setstate__", [](array& arr, const nb::ndarray& state) { new (&arr) array(nd_array_to_mlx(state, std::nullopt)); }) + .def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); }) .def("__copy__", [](const array& self) { return array(self); }) .def( "__deepcopy__", diff --git a/python/src/convert.cpp b/python/src/convert.cpp index cee16da80..c84168c09 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -100,8 +100,8 @@ array nd_array_to_mlx( } } -template -nb::ndarray mlx_to_nd_array( +template +nb::ndarray mlx_to_nd_array_impl( array a, std::optional t = {}) { { @@ -110,47 +110,51 @@ nb::ndarray mlx_to_nd_array( } std::vector shape(a.shape().begin(), a.shape().end()); std::vector strides(a.strides().begin(), a.strides().end()); - return nb::ndarray( + return nb::ndarray( a.data(), a.ndim(), shape.data(), - nb::handle(), + nb::none(), strides.data(), t.value_or(nb::dtype())); } -template -nb::ndarray mlx_to_nd_array(const array& a) { +template +nb::ndarray mlx_to_nd_array(const array& a) { switch (a.dtype()) { case bool_: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case uint8: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case uint16: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case uint32: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case uint64: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case int8: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case int16: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case int32: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case int64: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case float16: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case bfloat16: - return mlx_to_nd_array(a, nb::bfloat16); + return mlx_to_nd_array_impl(a, nb::bfloat16); case float32: - return mlx_to_nd_array(a); + return mlx_to_nd_array_impl(a); case complex64: - return mlx_to_nd_array>(a); + return mlx_to_nd_array_impl, NDParams...>(a); } } nb::ndarray mlx_to_np_array(const array& a) { return mlx_to_nd_array(a); } + +nb::ndarray<> mlx_to_dlpack(const array& a) { + return mlx_to_nd_array<>(a); +} diff --git a/python/src/convert.h b/python/src/convert.h index 36e868a7d..a28f24da6 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -13,4 +13,6 @@ using namespace mlx::core; array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype); + nb::ndarray mlx_to_np_array(const array& a); +nb::ndarray<> mlx_to_dlpack(const array& a); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index f79fad73e..e3ef4d88f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1722,6 +1722,20 @@ class TestArray(mlx_tests.MLXTestCase): self.assertEqual(z.dtype, mx.int32) self.assertEqual(z.item(), 3) + def test_dlpack(self): + x = mx.array(1, dtype=mx.int32) + y = np.from_dlpack(x) + self.assertTrue(mx.array_equal(y, x)) + + x = mx.array([[1.0, 2.0], [3.0, 4.0]]) + y = np.from_dlpack(x) + self.assertTrue(mx.array_equal(y, x)) + + x = mx.arange(16).reshape(4, 4) + x = x[::2, ::2] + y = np.from_dlpack(x) + self.assertTrue(mx.array_equal(y, x)) + if __name__ == "__main__": unittest.main()