Support pickling array for bfloat16 (#2586)

* add bfloat16 pickling

* Improvements

* improve

---------

Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
This commit is contained in:
Daniel Yeh
2025-09-23 05:12:15 +02:00
committed by GitHub
parent bf01ad9367
commit fbbf3b9b3e
4 changed files with 36 additions and 12 deletions

View File

@@ -23,8 +23,6 @@ struct ndarray_traits<mx::float16_t> {
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
int check_shape_dim(int64_t dim) {
@@ -51,6 +49,7 @@ mx::array nd_array_to_mlx(
std::optional<mx::Dtype> dtype) {
// Compute the shape and size
mx::Shape shape;
shape.reserve(nd_array.ndim());
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i)));
}