mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 16:28:10 +08:00
Support pickling array for bfloat16 (#2586)
* add bfloat16 pickling * Improvements * improve --------- Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
This commit is contained in:
@@ -466,12 +466,37 @@ void init_array(nb::module_& m) {
|
||||
})
|
||||
.def(
|
||||
"__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); })
|
||||
.def("__getstate__", &mlx_to_np_array)
|
||||
.def(
|
||||
"__getstate__",
|
||||
[](const mx::array& a) {
|
||||
auto nd = (a.dtype() == mx::bfloat16)
|
||||
? mlx_to_np_array(mx::view(a, mx::uint16))
|
||||
: mlx_to_np_array(a);
|
||||
return nb::make_tuple(nd, static_cast<uint8_t>(a.dtype().val()));
|
||||
})
|
||||
.def(
|
||||
"__setstate__",
|
||||
[](mx::array& arr,
|
||||
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
|
||||
new (&arr) mx::array(nd_array_to_mlx(state, std::nullopt));
|
||||
[](mx::array& arr, const nb::tuple& state) {
|
||||
if (nb::len(state) != 2) {
|
||||
throw std::invalid_argument(
|
||||
"Invalid pickle state: expected (ndarray, Dtype::Val)");
|
||||
}
|
||||
using ND = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
|
||||
ND nd = nb::cast<ND>(state[0]);
|
||||
auto val = static_cast<mx::Dtype::Val>(nb::cast<uint8_t>(state[1]));
|
||||
if (val == mx::Dtype::Val::bfloat16) {
|
||||
auto owner = nb::handle(state[0].ptr());
|
||||
new (&arr) mx::array(nd_array_to_mlx(
|
||||
ND(nd.data(),
|
||||
nd.ndim(),
|
||||
reinterpret_cast<const size_t*>(nd.shape_ptr()),
|
||||
owner,
|
||||
nullptr,
|
||||
nb::bfloat16),
|
||||
mx::bfloat16));
|
||||
} else {
|
||||
new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt));
|
||||
}
|
||||
})
|
||||
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
|
||||
.def(
|
||||
|
@@ -23,8 +23,6 @@ struct ndarray_traits<mx::float16_t> {
|
||||
static constexpr bool is_int = false;
|
||||
static constexpr bool is_signed = true;
|
||||
};
|
||||
|
||||
static constexpr dlpack::dtype bfloat16{4, 16, 1};
|
||||
}; // namespace nanobind
|
||||
|
||||
int check_shape_dim(int64_t dim) {
|
||||
@@ -51,6 +49,7 @@ mx::array nd_array_to_mlx(
|
||||
std::optional<mx::Dtype> dtype) {
|
||||
// Compute the shape and size
|
||||
mx::Shape shape;
|
||||
shape.reserve(nd_array.ndim());
|
||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
||||
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
||||
}
|
||||
|
@@ -12,6 +12,10 @@
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
||||
namespace nanobind {
|
||||
static constexpr dlpack::dtype bfloat16{4, 16, 1};
|
||||
}; // namespace nanobind
|
||||
|
||||
struct ArrayLike {
|
||||
ArrayLike(nb::object obj) : obj(obj) {};
|
||||
nb::object obj;
|
||||
|
@@ -532,7 +532,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(str(x), expected)
|
||||
|
||||
x = mx.array([[1, 2], [1, 2], [1, 2]])
|
||||
expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)"
|
||||
expected = "array([[1, 2],\n [1, 2],\n [1, 2]], dtype=int32)"
|
||||
self.assertEqual(str(x), expected)
|
||||
|
||||
x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
|
||||
@@ -886,6 +886,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
mx.uint64,
|
||||
mx.float16,
|
||||
mx.float32,
|
||||
mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
|
||||
@@ -895,11 +896,6 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
y = pickle.loads(state)
|
||||
self.assertEqualArray(y, x)
|
||||
|
||||
# check if it throws an error when dtype is not supported (bfloat16)
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
|
||||
with self.assertRaises(TypeError):
|
||||
pickle.dumps(x)
|
||||
|
||||
def test_array_copy(self):
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
|
Reference in New Issue
Block a user