mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Support pickling array for bfloat16 (#2586)
* add bfloat16 pickling * Improvements * improve --------- Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
This commit is contained in:
@@ -23,8 +23,6 @@ struct ndarray_traits<mx::float16_t> {
|
||||
static constexpr bool is_int = false;
|
||||
static constexpr bool is_signed = true;
|
||||
};
|
||||
|
||||
static constexpr dlpack::dtype bfloat16{4, 16, 1};
|
||||
}; // namespace nanobind
|
||||
|
||||
int check_shape_dim(int64_t dim) {
|
||||
@@ -51,6 +49,7 @@ mx::array nd_array_to_mlx(
|
||||
std::optional<mx::Dtype> dtype) {
|
||||
// Compute the shape and size
|
||||
mx::Shape shape;
|
||||
shape.reserve(nd_array.ndim());
|
||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
||||
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user