mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 21:21:16 +08:00
Fix NumPy 2.0 pickle test (#1221)
* fix numpy version <2 temporarily * typo * better fix * Fix just for bfloat16 --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
parent
af9079cc1f
commit
95d11bda06
@ -122,7 +122,7 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
||||
a.data<T>(),
|
||||
a.ndim(),
|
||||
shape.data(),
|
||||
nb::none(),
|
||||
/* owner= */ nb::none(),
|
||||
strides.data(),
|
||||
t.value_or(nb::dtype<T>()));
|
||||
}
|
||||
@ -151,7 +151,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
|
||||
case float16:
|
||||
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
||||
case bfloat16:
|
||||
return mlx_to_nd_array_impl<bfloat16_t, NDParams...>(a, nb::bfloat16);
|
||||
throw nb::type_error(
|
||||
"bfloat16 arrays cannot be converted directly to NumPy.");
|
||||
case float32:
|
||||
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||
case complex64:
|
||||
|
Loading…
Reference in New Issue
Block a user