mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
		| @@ -27,10 +27,18 @@ struct ndarray_traits<float16_t> { | ||||
| static constexpr dlpack::dtype bfloat16{4, 16, 1}; | ||||
| }; // namespace nanobind | ||||
|  | ||||
| int check_shape_dim(int64_t dim) { | ||||
|   if (dim > std::numeric_limits<int>::max()) { | ||||
|     throw std::invalid_argument( | ||||
|         "Shape dimension falls outside supported `int` range."); | ||||
|   } | ||||
|   return static_cast<int>(dim); | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| array nd_array_to_mlx_contiguous( | ||||
|     nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array, | ||||
|     const std::vector<int>& shape, | ||||
|     const Shape& shape, | ||||
|     Dtype dtype) { | ||||
|   // Make a copy of the numpy buffer | ||||
|   // Get buffer ptr pass to array constructor | ||||
| @@ -42,7 +50,7 @@ array nd_array_to_mlx( | ||||
|     nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array, | ||||
|     std::optional<Dtype> dtype) { | ||||
|   // Compute the shape and size | ||||
|   std::vector<int> shape; | ||||
|   Shape shape; | ||||
|   for (int i = 0; i < nd_array.ndim(); i++) { | ||||
|     shape.push_back(check_shape_dim(nd_array.shape(i))); | ||||
|   } | ||||
| @@ -108,13 +116,12 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl( | ||||
|     a.eval(); | ||||
|   } | ||||
|   std::vector<size_t> shape(a.shape().begin(), a.shape().end()); | ||||
|   std::vector<int64_t> strides(a.strides().begin(), a.strides().end()); | ||||
|   return nb::ndarray<NDParams...>( | ||||
|       a.data<T>(), | ||||
|       a.ndim(), | ||||
|       shape.data(), | ||||
|       /* owner= */ nb::none(), | ||||
|       strides.data(), | ||||
|       a.strides().data(), | ||||
|       t.value_or(nb::dtype<T>())); | ||||
| } | ||||
|  | ||||
| @@ -272,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) { | ||||
| template <typename T> | ||||
| PyScalarT validate_shape( | ||||
|     T list, | ||||
|     const std::vector<int>& shape, | ||||
|     const Shape& shape, | ||||
|     int idx, | ||||
|     bool& all_python_primitive_elements) { | ||||
|   if (idx >= shape.size()) { | ||||
| @@ -340,7 +347,7 @@ PyScalarT validate_shape( | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| void get_shape(T list, std::vector<int>& shape) { | ||||
| void get_shape(T list, Shape& shape) { | ||||
|   shape.push_back(check_shape_dim(nb::len(list))); | ||||
|   if (shape.back() > 0) { | ||||
|     auto l = list.begin(); | ||||
| @@ -351,7 +358,7 @@ void get_shape(T list, std::vector<int>& shape) { | ||||
|     } else if (nb::isinstance<array>(*l)) { | ||||
|       auto arr = nb::cast<array>(*l); | ||||
|       for (int i = 0; i < arr.ndim(); i++) { | ||||
|         shape.push_back(check_shape_dim(arr.shape(i))); | ||||
|         shape.push_back(arr.shape(i)); | ||||
|       } | ||||
|       return; | ||||
|     } | ||||
| @@ -363,7 +370,7 @@ array array_from_list_impl( | ||||
|     T pl, | ||||
|     const PyScalarT& inferred_type, | ||||
|     std::optional<Dtype> specified_type, | ||||
|     const std::vector<int>& shape) { | ||||
|     const Shape& shape) { | ||||
|   // Make the array | ||||
|   switch (inferred_type) { | ||||
|     case pybool: { | ||||
| @@ -420,7 +427,7 @@ array array_from_list_impl( | ||||
| template <typename T> | ||||
| array array_from_list_impl(T pl, std::optional<Dtype> dtype) { | ||||
|   // Compute the shape | ||||
|   std::vector<int> shape; | ||||
|   Shape shape; | ||||
|   get_shape(pl, shape); | ||||
|  | ||||
|   // Validate the shape and type | ||||
|   | ||||
| @@ -2953,16 +2953,16 @@ void init_ops(nb::module_& m) { | ||||
|   m.def( | ||||
|       "as_strided", | ||||
|       [](const array& a, | ||||
|          std::optional<std::vector<int>> shape, | ||||
|          std::optional<std::vector<size_t>> strides, | ||||
|          std::optional<Shape> shape, | ||||
|          std::optional<Strides> strides, | ||||
|          size_t offset, | ||||
|          StreamOrDevice s) { | ||||
|         std::vector<int> a_shape = (shape) ? *shape : a.shape(); | ||||
|         std::vector<size_t> a_strides; | ||||
|         auto a_shape = (shape) ? *shape : a.shape(); | ||||
|         Strides a_strides; | ||||
|         if (strides) { | ||||
|           a_strides = *strides; | ||||
|         } else { | ||||
|           a_strides = std::vector<size_t>(a_shape.size(), 1); | ||||
|           a_strides = Strides(a_shape.size(), 1); | ||||
|           for (int i = a_shape.size() - 1; i > 0; i--) { | ||||
|             a_strides[i - 1] = a_shape[i] * a_strides[i]; | ||||
|           } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun