mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 01:19:21 +08:00
allow conversion to dlpack (#1120)
This commit is contained in:
@@ -669,19 +669,14 @@ void init_array(nb::module_& m) {
|
||||
return a.shape(0);
|
||||
})
|
||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||
.def(
|
||||
"__getstate__",
|
||||
[](const array& a) {
|
||||
if (a.dtype() == bfloat16) {
|
||||
}
|
||||
return mlx_to_np_array(a);
|
||||
})
|
||||
.def("__getstate__", &mlx_to_np_array)
|
||||
.def(
|
||||
"__setstate__",
|
||||
[](array& arr,
|
||||
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
|
||||
new (&arr) array(nd_array_to_mlx(state, std::nullopt));
|
||||
})
|
||||
.def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); })
|
||||
.def("__copy__", [](const array& self) { return array(self); })
|
||||
.def(
|
||||
"__deepcopy__",
|
||||
|
||||
Reference in New Issue
Block a user