Pickle: workaround for bfloat16

This commit is contained in:
Luca Arnaboldi 2024-03-04 23:45:17 +01:00
parent 6b6b4f0a5f
commit 25afe6bc09
2 changed files with 19 additions and 7 deletions

View File

@ -776,13 +776,25 @@ void init_array(py::module_& m) {
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
.def(py::pickle( .def(py::pickle(
[](array& a) { // __getstate__ [](array& a) { // __getstate__
return py::array(buffer_info(a)); auto dtype = a.dtype();
}, if (dtype == bfloat16) {
[](py::array npa) { // __setstate__ array b = astype(a, float32);
if (not py::isinstance<py::array>(npa)) { return py::make_tuple(
throw std::runtime_error("Invalid state!"); dtype_to_array_protocol(dtype), py::array(buffer_info(b)));
} else {
return py::make_tuple(
dtype_to_array_protocol(dtype), py::array(buffer_info(a)));
} }
return np_array_to_mlx(npa, std::nullopt);
},
[](py::tuple t) { // __setstate__
if (t.size() != 2 or not py::isinstance<py::array>(t[1]) or
not py::isinstance<py::str>(t[0])) {
throw std::runtime_error("Invalid state passed to __setstate__!");
}
return np_array_to_mlx(
t[1].cast<py::array>(),
dtype_from_array_protocol(t[0].cast<std::string>()));
})) }))
.def("__copy__", [](const array& self) { return array(self); }) .def("__copy__", [](const array& self) { return array(self); })
.def( .def(

View File

@ -671,7 +671,7 @@ class TestArray(mlx_tests.MLXTestCase):
mx.uint64, mx.uint64,
mx.float16, mx.float16,
mx.float32, mx.float32,
# mx.bfloat16, mx.bfloat16,
mx.complex64, mx.complex64,
] ]