Fix NumPy 2.0 pickle test (#1221)

* fix numpy version <2 temporarily

* typo

* better fix

* Fix just for bfloat16

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron 2024-06-23 05:47:22 -07:00 committed by GitHub
parent af9079cc1f
commit 95d11bda06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -122,7 +122,7 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
a.data<T>(),
a.ndim(),
shape.data(),
nb::none(),
/* owner= */ nb::none(),
strides.data(),
t.value_or(nb::dtype<T>()));
}
@ -151,7 +151,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
case float16:
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
case bfloat16:
return mlx_to_nd_array_impl<bfloat16_t, NDParams...>(a, nb::bfloat16);
throw nb::type_error(
"bfloat16 arrays cannot be converted directly to NumPy.");
case float32:
return mlx_to_nd_array_impl<float, NDParams...>(a);
case complex64: