Support pickling array for bfloat16 (#2586)

* add bfloat16 pickling

* Improvements

* improve

---------

Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
This commit is contained in:
Daniel Yeh
2025-09-23 05:12:15 +02:00
committed by GitHub
parent bf01ad9367
commit fbbf3b9b3e
4 changed files with 36 additions and 12 deletions

View File

@@ -466,12 +466,37 @@ void init_array(nb::module_& m) {
})
.def(
"__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); })
.def("__getstate__", &mlx_to_np_array)
.def(
"__getstate__",
[](const mx::array& a) {
auto nd = (a.dtype() == mx::bfloat16)
? mlx_to_np_array(mx::view(a, mx::uint16))
: mlx_to_np_array(a);
return nb::make_tuple(nd, static_cast<uint8_t>(a.dtype().val()));
})
.def(
"__setstate__",
[](mx::array& arr,
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
new (&arr) mx::array(nd_array_to_mlx(state, std::nullopt));
[](mx::array& arr, const nb::tuple& state) {
if (nb::len(state) != 2) {
throw std::invalid_argument(
"Invalid pickle state: expected (ndarray, Dtype::Val)");
}
using ND = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
ND nd = nb::cast<ND>(state[0]);
auto val = static_cast<mx::Dtype::Val>(nb::cast<uint8_t>(state[1]));
if (val == mx::Dtype::Val::bfloat16) {
auto owner = nb::handle(state[0].ptr());
new (&arr) mx::array(nd_array_to_mlx(
ND(nd.data(),
nd.ndim(),
reinterpret_cast<const size_t*>(nd.shape_ptr()),
owner,
nullptr,
nb::bfloat16),
mx::bfloat16));
} else {
new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt));
}
})
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
.def(

View File

@@ -23,8 +23,6 @@ struct ndarray_traits<mx::float16_t> {
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
int check_shape_dim(int64_t dim) {
@@ -51,6 +49,7 @@ mx::array nd_array_to_mlx(
std::optional<mx::Dtype> dtype) {
// Compute the shape and size
mx::Shape shape;
shape.reserve(nd_array.ndim());
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i)));
}

View File

@@ -12,6 +12,10 @@
namespace mx = mlx::core;
namespace nb = nanobind;
namespace nanobind {
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
struct ArrayLike {
ArrayLike(nb::object obj) : obj(obj) {};
nb::object obj;

View File

@@ -532,7 +532,7 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(str(x), expected)
x = mx.array([[1, 2], [1, 2], [1, 2]])
expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)"
expected = "array([[1, 2],\n [1, 2],\n [1, 2]], dtype=int32)"
self.assertEqual(str(x), expected)
x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
@@ -886,6 +886,7 @@ class TestArray(mlx_tests.MLXTestCase):
mx.uint64,
mx.float16,
mx.float32,
mx.bfloat16,
mx.complex64,
]
@@ -895,11 +896,6 @@ class TestArray(mlx_tests.MLXTestCase):
y = pickle.loads(state)
self.assertEqualArray(y, x)
# check if it throws an error when dtype is not supported (bfloat16)
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
with self.assertRaises(TypeError):
pickle.dumps(x)
def test_array_copy(self):
dtypes = [
mx.int8,