mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
allow conversion to dlpack (#1120)
This commit is contained in:
parent
8b76571896
commit
81dd33af66
@ -669,19 +669,14 @@ void init_array(nb::module_& m) {
|
|||||||
return a.shape(0);
|
return a.shape(0);
|
||||||
})
|
})
|
||||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||||
.def(
|
.def("__getstate__", &mlx_to_np_array)
|
||||||
"__getstate__",
|
|
||||||
[](const array& a) {
|
|
||||||
if (a.dtype() == bfloat16) {
|
|
||||||
}
|
|
||||||
return mlx_to_np_array(a);
|
|
||||||
})
|
|
||||||
.def(
|
.def(
|
||||||
"__setstate__",
|
"__setstate__",
|
||||||
[](array& arr,
|
[](array& arr,
|
||||||
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
|
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
|
||||||
new (&arr) array(nd_array_to_mlx(state, std::nullopt));
|
new (&arr) array(nd_array_to_mlx(state, std::nullopt));
|
||||||
})
|
})
|
||||||
|
.def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); })
|
||||||
.def("__copy__", [](const array& self) { return array(self); })
|
.def("__copy__", [](const array& self) { return array(self); })
|
||||||
.def(
|
.def(
|
||||||
"__deepcopy__",
|
"__deepcopy__",
|
||||||
|
@ -100,8 +100,8 @@ array nd_array_to_mlx(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Lib, typename T>
|
template <typename T, typename... NDParams>
|
||||||
nb::ndarray<Lib> mlx_to_nd_array(
|
nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
||||||
array a,
|
array a,
|
||||||
std::optional<nb::dlpack::dtype> t = {}) {
|
std::optional<nb::dlpack::dtype> t = {}) {
|
||||||
{
|
{
|
||||||
@ -110,47 +110,51 @@ nb::ndarray<Lib> mlx_to_nd_array(
|
|||||||
}
|
}
|
||||||
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
|
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
|
||||||
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
|
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
|
||||||
return nb::ndarray<Lib>(
|
return nb::ndarray<NDParams...>(
|
||||||
a.data<T>(),
|
a.data<T>(),
|
||||||
a.ndim(),
|
a.ndim(),
|
||||||
shape.data(),
|
shape.data(),
|
||||||
nb::handle(),
|
nb::none(),
|
||||||
strides.data(),
|
strides.data(),
|
||||||
t.value_or(nb::dtype<T>()));
|
t.value_or(nb::dtype<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Lib>
|
template <typename... NDParams>
|
||||||
nb::ndarray<Lib> mlx_to_nd_array(const array& a) {
|
nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
return mlx_to_nd_array<Lib, bool>(a);
|
return mlx_to_nd_array_impl<bool, NDParams...>(a);
|
||||||
case uint8:
|
case uint8:
|
||||||
return mlx_to_nd_array<Lib, uint8_t>(a);
|
return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
|
||||||
case uint16:
|
case uint16:
|
||||||
return mlx_to_nd_array<Lib, uint16_t>(a);
|
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
|
||||||
case uint32:
|
case uint32:
|
||||||
return mlx_to_nd_array<Lib, uint32_t>(a);
|
return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
|
||||||
case uint64:
|
case uint64:
|
||||||
return mlx_to_nd_array<Lib, uint64_t>(a);
|
return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
|
||||||
case int8:
|
case int8:
|
||||||
return mlx_to_nd_array<Lib, int8_t>(a);
|
return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
|
||||||
case int16:
|
case int16:
|
||||||
return mlx_to_nd_array<Lib, int16_t>(a);
|
return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
|
||||||
case int32:
|
case int32:
|
||||||
return mlx_to_nd_array<Lib, int32_t>(a);
|
return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
|
||||||
case int64:
|
case int64:
|
||||||
return mlx_to_nd_array<Lib, int64_t>(a);
|
return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
|
||||||
case float16:
|
case float16:
|
||||||
return mlx_to_nd_array<Lib, float16_t>(a);
|
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return mlx_to_nd_array<Lib, bfloat16_t>(a, nb::bfloat16);
|
return mlx_to_nd_array_impl<bfloat16_t, NDParams...>(a, nb::bfloat16);
|
||||||
case float32:
|
case float32:
|
||||||
return mlx_to_nd_array<Lib, float>(a);
|
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||||
case complex64:
|
case complex64:
|
||||||
return mlx_to_nd_array<Lib, std::complex<float>>(a);
|
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
|
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
|
||||||
return mlx_to_nd_array<nb::numpy>(a);
|
return mlx_to_nd_array<nb::numpy>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nb::ndarray<> mlx_to_dlpack(const array& a) {
|
||||||
|
return mlx_to_nd_array<>(a);
|
||||||
|
}
|
||||||
|
@ -13,4 +13,6 @@ using namespace mlx::core;
|
|||||||
array nd_array_to_mlx(
|
array nd_array_to_mlx(
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||||
std::optional<Dtype> dtype);
|
std::optional<Dtype> dtype);
|
||||||
|
|
||||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
|
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
|
||||||
|
nb::ndarray<> mlx_to_dlpack(const array& a);
|
||||||
|
@ -1722,6 +1722,20 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(z.dtype, mx.int32)
|
self.assertEqual(z.dtype, mx.int32)
|
||||||
self.assertEqual(z.item(), 3)
|
self.assertEqual(z.item(), 3)
|
||||||
|
|
||||||
|
def test_dlpack(self):
|
||||||
|
x = mx.array(1, dtype=mx.int32)
|
||||||
|
y = np.from_dlpack(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, x))
|
||||||
|
|
||||||
|
x = mx.array([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
y = np.from_dlpack(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, x))
|
||||||
|
|
||||||
|
x = mx.arange(16).reshape(4, 4)
|
||||||
|
x = x[::2, ::2]
|
||||||
|
y = np.from_dlpack(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, x))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user