diff --git a/python/src/array.cpp b/python/src/array.cpp index 8bca3ba3e..c491b44e0 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -776,13 +776,25 @@ void init_array(py::module_& m) { .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) .def(py::pickle( [](array& a) { // __getstate__ - return py::array(buffer_info(a)); - }, - [](py::array npa) { // __setstate__ - if (not py::isinstance(npa)) { - throw std::runtime_error("Invalid state!"); + auto dtype = a.dtype(); + if (dtype == bfloat16) { + array b = astype(a, float32); + return py::make_tuple( + dtype_to_array_protocol(dtype), py::array(buffer_info(b))); + } else { + return py::make_tuple( + dtype_to_array_protocol(dtype), py::array(buffer_info(a))); } - return np_array_to_mlx(npa, std::nullopt); + + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 2 or not py::isinstance(t[1]) or + not py::isinstance(t[0])) { + throw std::runtime_error("Invalid state passed to __setstate__!"); + } + return np_array_to_mlx( + t[1].cast(), + dtype_from_array_protocol(t[0].cast())); })) .def("__copy__", [](const array& self) { return array(self); }) .def( diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 4409baa51..52bc6cf12 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -671,7 +671,7 @@ class TestArray(mlx_tests.MLXTestCase): mx.uint64, mx.float16, mx.float32, - # mx.bfloat16, + mx.bfloat16, mx.complex64, ]