mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Pickle: workaround for bfloat16
This commit is contained in:
parent
6b6b4f0a5f
commit
25afe6bc09
@ -776,13 +776,25 @@ void init_array(py::module_& m) {
|
|||||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](array& a) { // __getstate__
|
[](array& a) { // __getstate__
|
||||||
return py::array(buffer_info(a));
|
auto dtype = a.dtype();
|
||||||
},
|
if (dtype == bfloat16) {
|
||||||
[](py::array npa) { // __setstate__
|
array b = astype(a, float32);
|
||||||
if (not py::isinstance<py::array>(npa)) {
|
return py::make_tuple(
|
||||||
throw std::runtime_error("Invalid state!");
|
dtype_to_array_protocol(dtype), py::array(buffer_info(b)));
|
||||||
|
} else {
|
||||||
|
return py::make_tuple(
|
||||||
|
dtype_to_array_protocol(dtype), py::array(buffer_info(a)));
|
||||||
}
|
}
|
||||||
return np_array_to_mlx(npa, std::nullopt);
|
|
||||||
|
},
|
||||||
|
[](py::tuple t) { // __setstate__
|
||||||
|
if (t.size() != 2 or not py::isinstance<py::array>(t[1]) or
|
||||||
|
not py::isinstance<py::str>(t[0])) {
|
||||||
|
throw std::runtime_error("Invalid state passed to __setstate__!");
|
||||||
|
}
|
||||||
|
return np_array_to_mlx(
|
||||||
|
t[1].cast<py::array>(),
|
||||||
|
dtype_from_array_protocol(t[0].cast<std::string>()));
|
||||||
}))
|
}))
|
||||||
.def("__copy__", [](const array& self) { return array(self); })
|
.def("__copy__", [](const array& self) { return array(self); })
|
||||||
.def(
|
.def(
|
||||||
|
@ -671,7 +671,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
mx.uint64,
|
mx.uint64,
|
||||||
mx.float16,
|
mx.float16,
|
||||||
mx.float32,
|
mx.float32,
|
||||||
# mx.bfloat16,
|
mx.bfloat16,
|
||||||
mx.complex64,
|
mx.complex64,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user