mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 22:11:15 +08:00
Fix NumPy 2.0 pickle test (#1221)
* fix numpy version <2 temporarily * typo * better fix * Fix just for bfloat16 --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
parent
af9079cc1f
commit
95d11bda06
@ -122,7 +122,7 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
|||||||
a.data<T>(),
|
a.data<T>(),
|
||||||
a.ndim(),
|
a.ndim(),
|
||||||
shape.data(),
|
shape.data(),
|
||||||
nb::none(),
|
/* owner= */ nb::none(),
|
||||||
strides.data(),
|
strides.data(),
|
||||||
t.value_or(nb::dtype<T>()));
|
t.value_or(nb::dtype<T>()));
|
||||||
}
|
}
|
||||||
@ -151,7 +151,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
|
|||||||
case float16:
|
case float16:
|
||||||
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return mlx_to_nd_array_impl<bfloat16_t, NDParams...>(a, nb::bfloat16);
|
throw nb::type_error(
|
||||||
|
"bfloat16 arrays cannot be converted directly to NumPy.");
|
||||||
case float32:
|
case float32:
|
||||||
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||||
case complex64:
|
case complex64:
|
||||||
|
Loading…
Reference in New Issue
Block a user