mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	* Ensure shape dimensions are within supported integer range (#566) * fix build * fix rebase bug --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -196,7 +196,7 @@ PyScalarT validate_shape( | ||||
|  | ||||
| template <typename T> | ||||
| void get_shape(T list, std::vector<int>& shape) { | ||||
|   shape.push_back(nb::len(list)); | ||||
|   shape.push_back(check_shape_dim(nb::len(list))); | ||||
|   if (shape.back() > 0) { | ||||
|     auto l = list.begin(); | ||||
|     if (nb::isinstance<nb::list>(*l)) { | ||||
| @@ -205,7 +205,9 @@ void get_shape(T list, std::vector<int>& shape) { | ||||
|       return get_shape(nb::cast<nb::tuple>(*l), shape); | ||||
|     } else if (nb::isinstance<array>(*l)) { | ||||
|       auto arr = nb::cast<array>(*l); | ||||
|       shape.insert(shape.end(), arr.shape().begin(), arr.shape().end()); | ||||
|       for (int i = 0; i < arr.ndim(); i++) { | ||||
|         shape.push_back(check_shape_dim(arr.shape(i))); | ||||
|       } | ||||
|       return; | ||||
|     } | ||||
|   } | ||||
|   | ||||
| @@ -4,6 +4,8 @@ | ||||
|  | ||||
| #include "python/src/convert.h" | ||||
|  | ||||
| #include "mlx/utils.h" | ||||
|  | ||||
| namespace nanobind { | ||||
| template <> | ||||
| struct ndarray_traits<float16_t> { | ||||
| @@ -43,7 +45,7 @@ array nd_array_to_mlx( | ||||
|   // Compute the shape and size | ||||
|   std::vector<int> shape; | ||||
|   for (int i = 0; i < nd_array.ndim(); i++) { | ||||
|     shape.push_back(nd_array.shape(i)); | ||||
|     shape.push_back(check_shape_dim(nd_array.shape(i))); | ||||
|   } | ||||
|   auto type = nd_array.dtype(); | ||||
|  | ||||
|   | ||||
| @@ -534,6 +534,14 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|  | ||||
|             self.assertEqual(b_npy.dtype, np_dtype) | ||||
|  | ||||
|     def test_array_np_shape_dim_check(self): | ||||
|         a_npy = np.empty(2**31, dtype=np.bool_) | ||||
|         with self.assertRaises(ValueError) as e: | ||||
|             mx.array(a_npy) | ||||
|         self.assertEqual( | ||||
|             str(e.exception), "Shape dimension falls outside supported `int` range." | ||||
|         ) | ||||
|  | ||||
|     def test_dtype_promotion(self): | ||||
|         dtypes_list = [ | ||||
|             (mx.bool_, np.bool_), | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jack Mousseau
					Jack Mousseau