From 95d11bda06050199bceaeb2be1bbcacff2d4be73 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Sun, 23 Jun 2024 05:47:22 -0700 Subject: [PATCH] Fix NumPy 2.0 pickle test (#1221) * fix numpy version <2 temporarily * typo * better fix * Fix just for bfloat16 --------- Co-authored-by: Alex Barron --- python/src/convert.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 5f4cb127d..9c4d71b1b 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -122,7 +122,7 @@ nb::ndarray mlx_to_nd_array_impl( a.data(), a.ndim(), shape.data(), - nb::none(), + /* owner= */ nb::none(), strides.data(), t.value_or(nb::dtype())); } @@ -151,7 +151,8 @@ nb::ndarray mlx_to_nd_array(const array& a) { case float16: return mlx_to_nd_array_impl(a); case bfloat16: - return mlx_to_nd_array_impl(a, nb::bfloat16); + throw nb::type_error( + "bfloat16 arrays cannot be converted directly to NumPy."); case float32: return mlx_to_nd_array_impl(a); case complex64: