mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
* Implemented pickling and copy for Python arrays(#300 & #367)
* Fixing typos
* Pickle with NumPy arrays
* Pickle: workaround for bfloat16
* Revert "Pickle: workaround for bfloat16"
This reverts commit 25afe6bc09
.
* Added an error when pickling bfloat16
* Update python/tests/test_array.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/tests/test_array.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/array.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/array.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* clang-format applied
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -458,6 +458,21 @@ std::vector<size_t> buffer_strides(const array& a) {
|
||||
return py_strides;
|
||||
}
|
||||
|
||||
py::buffer_info buffer_info(array& a) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
py::gil_scoped_release nogil;
|
||||
a.eval();
|
||||
}
|
||||
return pybind11::buffer_info(
|
||||
a.data<void>(),
|
||||
a.itemsize(),
|
||||
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
|
||||
// std::string which can't be null
|
||||
a.ndim(),
|
||||
a.shape(),
|
||||
buffer_strides(a));
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Module
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -647,21 +662,7 @@ void init_array(py::module_& m) {
|
||||
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
|
||||
|
||||
array_class
|
||||
.def_buffer([](array& a) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
py::gil_scoped_release nogil;
|
||||
a.eval();
|
||||
}
|
||||
return pybind11::buffer_info(
|
||||
a.data<void>(),
|
||||
a.itemsize(),
|
||||
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
|
||||
// std::string which can't be null
|
||||
a.ndim(),
|
||||
a.shape(),
|
||||
buffer_strides(a));
|
||||
})
|
||||
.def_buffer([](array& a) -> py::buffer_info { return buffer_info(a); })
|
||||
.def_property_readonly(
|
||||
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
|
||||
.def_property_readonly(
|
||||
@@ -773,6 +774,26 @@ void init_array(py::module_& m) {
|
||||
return a.shape(0);
|
||||
})
|
||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||
.def(py::pickle(
|
||||
[](array& a) { // __getstate__
|
||||
if (a.dtype() == bfloat16) {
|
||||
throw std::runtime_error(
|
||||
"[array.__getstate__] Not supported for bfloat16.");
|
||||
}
|
||||
return py::array(buffer_info(a));
|
||||
},
|
||||
[](py::array npa) { // __setstate__
|
||||
if (not py::isinstance<py::array>(npa)) {
|
||||
throw std::runtime_error(
|
||||
"[array.__setstate__] Received invalid state.");
|
||||
}
|
||||
return np_array_to_mlx(npa, std::nullopt);
|
||||
}))
|
||||
.def("__copy__", [](const array& self) { return array(self); })
|
||||
.def(
|
||||
"__deepcopy__",
|
||||
[](const array& self, py::dict) { return array(self); },
|
||||
"memo"_a)
|
||||
.def(
|
||||
"__add__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
|
Reference in New Issue
Block a user