mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Fix NumPy 2.0 pickle test (#1221)
* fix numpy version <2 temporarily * typo * better fix * Fix just for bfloat16 --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
		@@ -122,7 +122,7 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
 | 
			
		||||
      a.data<T>(),
 | 
			
		||||
      a.ndim(),
 | 
			
		||||
      shape.data(),
 | 
			
		||||
      nb::none(),
 | 
			
		||||
      /* owner= */ nb::none(),
 | 
			
		||||
      strides.data(),
 | 
			
		||||
      t.value_or(nb::dtype<T>()));
 | 
			
		||||
}
 | 
			
		||||
@@ -151,7 +151,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
 | 
			
		||||
    case float16:
 | 
			
		||||
      return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
 | 
			
		||||
    case bfloat16:
 | 
			
		||||
      return mlx_to_nd_array_impl<bfloat16_t, NDParams...>(a, nb::bfloat16);
 | 
			
		||||
      throw nb::type_error(
 | 
			
		||||
          "bfloat16 arrays cannot be converted directly to NumPy.");
 | 
			
		||||
    case float32:
 | 
			
		||||
      return mlx_to_nd_array_impl<float, NDParams...>(a);
 | 
			
		||||
    case complex64:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user