mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 16:28:10 +08:00
Support pickling array for bfloat16 (#2586)
* add bfloat16 pickling * Improvements * improve --------- Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
This commit is contained in:
@@ -466,12 +466,37 @@ void init_array(nb::module_& m) {
|
|||||||
})
|
})
|
||||||
.def(
|
.def(
|
||||||
"__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); })
|
"__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); })
|
||||||
.def("__getstate__", &mlx_to_np_array)
|
.def(
|
||||||
|
"__getstate__",
|
||||||
|
[](const mx::array& a) {
|
||||||
|
auto nd = (a.dtype() == mx::bfloat16)
|
||||||
|
? mlx_to_np_array(mx::view(a, mx::uint16))
|
||||||
|
: mlx_to_np_array(a);
|
||||||
|
return nb::make_tuple(nd, static_cast<uint8_t>(a.dtype().val()));
|
||||||
|
})
|
||||||
.def(
|
.def(
|
||||||
"__setstate__",
|
"__setstate__",
|
||||||
[](mx::array& arr,
|
[](mx::array& arr, const nb::tuple& state) {
|
||||||
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
|
if (nb::len(state) != 2) {
|
||||||
new (&arr) mx::array(nd_array_to_mlx(state, std::nullopt));
|
throw std::invalid_argument(
|
||||||
|
"Invalid pickle state: expected (ndarray, Dtype::Val)");
|
||||||
|
}
|
||||||
|
using ND = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
|
||||||
|
ND nd = nb::cast<ND>(state[0]);
|
||||||
|
auto val = static_cast<mx::Dtype::Val>(nb::cast<uint8_t>(state[1]));
|
||||||
|
if (val == mx::Dtype::Val::bfloat16) {
|
||||||
|
auto owner = nb::handle(state[0].ptr());
|
||||||
|
new (&arr) mx::array(nd_array_to_mlx(
|
||||||
|
ND(nd.data(),
|
||||||
|
nd.ndim(),
|
||||||
|
reinterpret_cast<const size_t*>(nd.shape_ptr()),
|
||||||
|
owner,
|
||||||
|
nullptr,
|
||||||
|
nb::bfloat16),
|
||||||
|
mx::bfloat16));
|
||||||
|
} else {
|
||||||
|
new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt));
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
|
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
|
||||||
.def(
|
.def(
|
||||||
|
@@ -23,8 +23,6 @@ struct ndarray_traits<mx::float16_t> {
|
|||||||
static constexpr bool is_int = false;
|
static constexpr bool is_int = false;
|
||||||
static constexpr bool is_signed = true;
|
static constexpr bool is_signed = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
static constexpr dlpack::dtype bfloat16{4, 16, 1};
|
|
||||||
}; // namespace nanobind
|
}; // namespace nanobind
|
||||||
|
|
||||||
int check_shape_dim(int64_t dim) {
|
int check_shape_dim(int64_t dim) {
|
||||||
@@ -51,6 +49,7 @@ mx::array nd_array_to_mlx(
|
|||||||
std::optional<mx::Dtype> dtype) {
|
std::optional<mx::Dtype> dtype) {
|
||||||
// Compute the shape and size
|
// Compute the shape and size
|
||||||
mx::Shape shape;
|
mx::Shape shape;
|
||||||
|
shape.reserve(nd_array.ndim());
|
||||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
for (int i = 0; i < nd_array.ndim(); i++) {
|
||||||
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
||||||
}
|
}
|
||||||
|
@@ -12,6 +12,10 @@
|
|||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
|
||||||
|
namespace nanobind {
|
||||||
|
static constexpr dlpack::dtype bfloat16{4, 16, 1};
|
||||||
|
}; // namespace nanobind
|
||||||
|
|
||||||
struct ArrayLike {
|
struct ArrayLike {
|
||||||
ArrayLike(nb::object obj) : obj(obj) {};
|
ArrayLike(nb::object obj) : obj(obj) {};
|
||||||
nb::object obj;
|
nb::object obj;
|
||||||
|
@@ -532,7 +532,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(str(x), expected)
|
self.assertEqual(str(x), expected)
|
||||||
|
|
||||||
x = mx.array([[1, 2], [1, 2], [1, 2]])
|
x = mx.array([[1, 2], [1, 2], [1, 2]])
|
||||||
expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)"
|
expected = "array([[1, 2],\n [1, 2],\n [1, 2]], dtype=int32)"
|
||||||
self.assertEqual(str(x), expected)
|
self.assertEqual(str(x), expected)
|
||||||
|
|
||||||
x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
|
x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
|
||||||
@@ -886,6 +886,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
mx.uint64,
|
mx.uint64,
|
||||||
mx.float16,
|
mx.float16,
|
||||||
mx.float32,
|
mx.float32,
|
||||||
|
mx.bfloat16,
|
||||||
mx.complex64,
|
mx.complex64,
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -895,11 +896,6 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
y = pickle.loads(state)
|
y = pickle.loads(state)
|
||||||
self.assertEqualArray(y, x)
|
self.assertEqualArray(y, x)
|
||||||
|
|
||||||
# check if it throws an error when dtype is not supported (bfloat16)
|
|
||||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
pickle.dumps(x)
|
|
||||||
|
|
||||||
def test_array_copy(self):
|
def test_array_copy(self):
|
||||||
dtypes = [
|
dtypes = [
|
||||||
mx.int8,
|
mx.int8,
|
||||||
|
Reference in New Issue
Block a user