mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Pickle: workaround for bfloat16
This commit is contained in:
parent
6b6b4f0a5f
commit
25afe6bc09
@ -776,13 +776,25 @@ void init_array(py::module_& m) {
|
||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||
.def(py::pickle(
|
||||
[](array& a) { // __getstate__
|
||||
return py::array(buffer_info(a));
|
||||
},
|
||||
[](py::array npa) { // __setstate__
|
||||
if (not py::isinstance<py::array>(npa)) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
auto dtype = a.dtype();
|
||||
if (dtype == bfloat16) {
|
||||
array b = astype(a, float32);
|
||||
return py::make_tuple(
|
||||
dtype_to_array_protocol(dtype), py::array(buffer_info(b)));
|
||||
} else {
|
||||
return py::make_tuple(
|
||||
dtype_to_array_protocol(dtype), py::array(buffer_info(a)));
|
||||
}
|
||||
return np_array_to_mlx(npa, std::nullopt);
|
||||
|
||||
},
|
||||
[](py::tuple t) { // __setstate__
|
||||
if (t.size() != 2 or not py::isinstance<py::array>(t[1]) or
|
||||
not py::isinstance<py::str>(t[0])) {
|
||||
throw std::runtime_error("Invalid state passed to __setstate__!");
|
||||
}
|
||||
return np_array_to_mlx(
|
||||
t[1].cast<py::array>(),
|
||||
dtype_from_array_protocol(t[0].cast<std::string>()));
|
||||
}))
|
||||
.def("__copy__", [](const array& self) { return array(self); })
|
||||
.def(
|
||||
|
@ -671,7 +671,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
mx.uint64,
|
||||
mx.float16,
|
||||
mx.float32,
|
||||
# mx.bfloat16,
|
||||
mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user