allow conversion to dlpack (#1120)

This commit is contained in:
Awni Hannun
2024-05-16 16:11:37 -07:00
committed by GitHub
parent 8b76571896
commit 81dd33af66
4 changed files with 41 additions and 26 deletions

View File

@@ -669,19 +669,14 @@ void init_array(nb::module_& m) {
return a.shape(0);
})
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
.def(
"__getstate__",
[](const array& a) {
if (a.dtype() == bfloat16) {
}
return mlx_to_np_array(a);
})
.def("__getstate__", &mlx_to_np_array)
.def(
"__setstate__",
[](array& arr,
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
new (&arr) array(nd_array_to_mlx(state, std::nullopt));
})
.def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); })
.def("__copy__", [](const array& self) { return array(self); })
.def(
"__deepcopy__",