Support pickling array for bfloat16 (#2586)

* add bfloat16 pickling

* Improvements

* improve

---------

Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
This commit is contained in:
Daniel Yeh
2025-09-23 05:12:15 +02:00
committed by GitHub
parent bf01ad9367
commit fbbf3b9b3e
4 changed files with 36 additions and 12 deletions

View File

@@ -466,12 +466,37 @@ void init_array(nb::module_& m) {
}) })
.def( .def(
"__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); }) "__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); })
.def("__getstate__", &mlx_to_np_array) .def(
"__getstate__",
[](const mx::array& a) {
auto nd = (a.dtype() == mx::bfloat16)
? mlx_to_np_array(mx::view(a, mx::uint16))
: mlx_to_np_array(a);
return nb::make_tuple(nd, static_cast<uint8_t>(a.dtype().val()));
})
.def( .def(
"__setstate__", "__setstate__",
[](mx::array& arr, [](mx::array& arr, const nb::tuple& state) {
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) { if (nb::len(state) != 2) {
new (&arr) mx::array(nd_array_to_mlx(state, std::nullopt)); throw std::invalid_argument(
"Invalid pickle state: expected (ndarray, Dtype::Val)");
}
using ND = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
ND nd = nb::cast<ND>(state[0]);
auto val = static_cast<mx::Dtype::Val>(nb::cast<uint8_t>(state[1]));
if (val == mx::Dtype::Val::bfloat16) {
auto owner = nb::handle(state[0].ptr());
new (&arr) mx::array(nd_array_to_mlx(
ND(nd.data(),
nd.ndim(),
reinterpret_cast<const size_t*>(nd.shape_ptr()),
owner,
nullptr,
nb::bfloat16),
mx::bfloat16));
} else {
new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt));
}
}) })
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
.def( .def(

View File

@@ -23,8 +23,6 @@ struct ndarray_traits<mx::float16_t> {
static constexpr bool is_int = false; static constexpr bool is_int = false;
static constexpr bool is_signed = true; static constexpr bool is_signed = true;
}; };
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind }; // namespace nanobind
int check_shape_dim(int64_t dim) { int check_shape_dim(int64_t dim) {
@@ -51,6 +49,7 @@ mx::array nd_array_to_mlx(
std::optional<mx::Dtype> dtype) { std::optional<mx::Dtype> dtype) {
// Compute the shape and size // Compute the shape and size
mx::Shape shape; mx::Shape shape;
shape.reserve(nd_array.ndim());
for (int i = 0; i < nd_array.ndim(); i++) { for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i))); shape.push_back(check_shape_dim(nd_array.shape(i)));
} }

View File

@@ -12,6 +12,10 @@
namespace mx = mlx::core; namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
namespace nanobind {
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
struct ArrayLike { struct ArrayLike {
ArrayLike(nb::object obj) : obj(obj) {}; ArrayLike(nb::object obj) : obj(obj) {};
nb::object obj; nb::object obj;

View File

@@ -532,7 +532,7 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(str(x), expected) self.assertEqual(str(x), expected)
x = mx.array([[1, 2], [1, 2], [1, 2]]) x = mx.array([[1, 2], [1, 2], [1, 2]])
expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)" expected = "array([[1, 2],\n [1, 2],\n [1, 2]], dtype=int32)"
self.assertEqual(str(x), expected) self.assertEqual(str(x), expected)
x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]) x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
@@ -886,6 +886,7 @@ class TestArray(mlx_tests.MLXTestCase):
mx.uint64, mx.uint64,
mx.float16, mx.float16,
mx.float32, mx.float32,
mx.bfloat16,
mx.complex64, mx.complex64,
] ]
@@ -895,11 +896,6 @@ class TestArray(mlx_tests.MLXTestCase):
y = pickle.loads(state) y = pickle.loads(state)
self.assertEqualArray(y, x) self.assertEqualArray(y, x)
# check if it throws an error when dtype is not supported (bfloat16)
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
with self.assertRaises(TypeError):
pickle.dumps(x)
def test_array_copy(self): def test_array_copy(self):
dtypes = [ dtypes = [
mx.int8, mx.int8,