From fbbf3b9b3e476f2e6e0f0c317635eafb13b321d4 Mon Sep 17 00:00:00 2001 From: Daniel Yeh <46629671+Dan-Yeh@users.noreply.github.com> Date: Tue, 23 Sep 2025 05:12:15 +0200 Subject: [PATCH] Support pickling array for bfloat16 (#2586) * add bfloat16 pickling * Improvements * improve --------- Co-authored-by: Chen-Chen Yeh --- python/src/array.cpp | 33 +++++++++++++++++++++++++++++---- python/src/convert.cpp | 3 +-- python/src/convert.h | 4 ++++ python/tests/test_array.py | 8 ++------ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index ae38fa211..9367d4e09 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -466,12 +466,37 @@ void init_array(nb::module_& m) { }) .def( "__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); }) - .def("__getstate__", &mlx_to_np_array) + .def( + "__getstate__", + [](const mx::array& a) { + auto nd = (a.dtype() == mx::bfloat16) + ? mlx_to_np_array(mx::view(a, mx::uint16)) + : mlx_to_np_array(a); + return nb::make_tuple(nd, static_cast(a.dtype().val())); + }) .def( "__setstate__", - [](mx::array& arr, - const nb::ndarray& state) { - new (&arr) mx::array(nd_array_to_mlx(state, std::nullopt)); + [](mx::array& arr, const nb::tuple& state) { + if (nb::len(state) != 2) { + throw std::invalid_argument( + "Invalid pickle state: expected (ndarray, Dtype::Val)"); + } + using ND = nb::ndarray; + ND nd = nb::cast(state[0]); + auto val = static_cast(nb::cast(state[1])); + if (val == mx::Dtype::Val::bfloat16) { + auto owner = nb::handle(state[0].ptr()); + new (&arr) mx::array(nd_array_to_mlx( + ND(nd.data(), + nd.ndim(), + reinterpret_cast(nd.shape_ptr()), + owner, + nullptr, + nb::bfloat16), + mx::bfloat16)); + } else { + new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt)); + } }) .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) .def( diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 1340b663a..88da37103 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -23,8 +23,6 @@ struct ndarray_traits { static constexpr bool is_int = false; static constexpr bool is_signed = true; }; - -static constexpr dlpack::dtype bfloat16{4, 16, 1}; }; // namespace nanobind int check_shape_dim(int64_t dim) { @@ -51,6 +49,7 @@ mx::array nd_array_to_mlx( std::optional dtype) { // Compute the shape and size mx::Shape shape; + shape.reserve(nd_array.ndim()); for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } diff --git a/python/src/convert.h b/python/src/convert.h index f5016c8af..4f8a77abe 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -12,6 +12,10 @@ namespace mx = mlx::core; namespace nb = nanobind; +namespace nanobind { +static constexpr dlpack::dtype bfloat16{4, 16, 1}; +}; // namespace nanobind + struct ArrayLike { ArrayLike(nb::object obj) : obj(obj) {}; nb::object obj; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index ae1cb784f..e932382b1 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -532,7 +532,7 @@ class TestArray(mlx_tests.MLXTestCase): self.assertEqual(str(x), expected) x = mx.array([[1, 2], [1, 2], [1, 2]]) - expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)" + expected = "array([[1, 2],\n [1, 2],\n [1, 2]], dtype=int32)" self.assertEqual(str(x), expected) x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]) @@ -886,6 +886,7 @@ class TestArray(mlx_tests.MLXTestCase): mx.uint64, mx.float16, mx.float32, + mx.bfloat16, mx.complex64, ] @@ -895,11 +896,6 @@ class TestArray(mlx_tests.MLXTestCase): y = pickle.loads(state) self.assertEqualArray(y, x) - # check if it throws an error when dtype is not supported (bfloat16) - x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16) - with self.assertRaises(TypeError): - pickle.dumps(x) - def test_array_copy(self): dtypes = [ mx.int8,