mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places * few more * few more * fix
This commit is contained in:
		| @@ -176,7 +176,7 @@ mx::array mlx_gather_nd( | ||||
|   for (auto& ax : axes) { | ||||
|     ax += max_dims + num_slices; | ||||
|   } | ||||
|   return squeeze(src, axes); | ||||
|   return mx::squeeze(src, axes); | ||||
| } | ||||
|  | ||||
| auto mlx_expand_ellipsis( | ||||
| @@ -438,9 +438,7 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { | ||||
|   } else if (nb::isinstance<nb::ellipsis>(obj)) { | ||||
|     return src; | ||||
|   } else if (obj.is_none()) { | ||||
|     std::vector<int> s(1, 1); | ||||
|     s.insert(s.end(), src.shape().begin(), src.shape().end()); | ||||
|     return reshape(src, s); | ||||
|     return expand_dims(src, 0); | ||||
|   } else if (nb::isinstance<nb::list>(obj)) { | ||||
|     return mlx_get_item_array( | ||||
|         src, array_from_list(nb::cast<nb::list>(obj), {})); | ||||
| @@ -474,6 +472,15 @@ mlx_scatter_args_int( | ||||
|       {0}}; | ||||
| } | ||||
|  | ||||
| mx::array squeeze_leading_singletons(const mx::array& in) { | ||||
|   int s = 0; | ||||
|   for (; s < in.ndim() && in.shape(s) == 1; s++) | ||||
|     ; | ||||
|   auto squeeze_axes = std::vector<int>(s); | ||||
|   std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0); | ||||
|   return mx::squeeze(in, squeeze_axes); | ||||
| } | ||||
|  | ||||
| std::tuple<std::vector<mx::array>, mx::array, std::vector<int>> | ||||
| mlx_scatter_args_array( | ||||
|     const mx::array& src, | ||||
| @@ -484,16 +491,10 @@ mlx_scatter_args_array( | ||||
|         "too many indices for array: array is 0-dimensional"); | ||||
|   } | ||||
|  | ||||
|   // Remove any leading singleton dimensions from the update | ||||
|   int s = 0; | ||||
|   for (; s < update.ndim() && update.shape(s) == 1; s++) | ||||
|     ; | ||||
|   auto up_shape = | ||||
|       std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|   auto up = reshape(update, up_shape); | ||||
|   auto up = squeeze_leading_singletons(update); | ||||
|  | ||||
|   // The update shape must broadcast with indices.shape + [1] + src.shape[1:] | ||||
|   up_shape = indices.shape(); | ||||
|   auto up_shape = indices.shape(); | ||||
|   up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end()); | ||||
|   up = broadcast_to(up, up_shape); | ||||
|   up_shape.insert(up_shape.begin() + indices.ndim(), 1); | ||||
| @@ -516,12 +517,8 @@ mlx_scatter_args_slice( | ||||
|   // If none slice is requested broadcast the update | ||||
|   // to the src size and return it. | ||||
|   if (is_none_slice(in_slice)) { | ||||
|     int s = 0; | ||||
|     for (; s < update.ndim() && update.shape(s) == 1; s++) | ||||
|       ; | ||||
|     auto up_shape = | ||||
|         std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|     return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}}; | ||||
|     return { | ||||
|         {}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}}; | ||||
|   } | ||||
|  | ||||
|   int start = 0; | ||||
| @@ -534,12 +531,7 @@ mlx_scatter_args_slice( | ||||
|   // If simple stride | ||||
|   if (stride == 1) { | ||||
|     // Squeeze out singleton dims from the start of update | ||||
|     int s = 0; | ||||
|     for (; s < update.ndim() && update.shape(s) == 1; s++) | ||||
|       ; | ||||
|     auto up_shape = | ||||
|         std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|     auto up = reshape(update, up_shape); | ||||
|     auto up = squeeze_leading_singletons(update); | ||||
|  | ||||
|     // Build array to mark start of slice | ||||
|     auto idx = mx::array({start}, {1}, mx::uint32); | ||||
| @@ -548,7 +540,7 @@ mlx_scatter_args_slice( | ||||
|     int slice_size = (end - start); | ||||
|  | ||||
|     // Broadcast update to slice size | ||||
|     std::vector<int> up_shape_broadcast = {1, slice_size}; | ||||
|     mx::Shape up_shape_broadcast = {1, slice_size}; | ||||
|     up_shape_broadcast.insert( | ||||
|         up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end()); | ||||
|  | ||||
| @@ -585,13 +577,7 @@ mlx_scatter_args_nd( | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|  | ||||
|   // Remove leading singletons dimensions from the update | ||||
|   int s = 0; | ||||
|   for (; s < update.ndim() && update.shape(s) == 1; s++) { | ||||
|   }; | ||||
|   auto up_shape = | ||||
|       std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|   auto up = reshape(update, up_shape); | ||||
|   auto up = squeeze_leading_singletons(update); | ||||
|  | ||||
|   // If no non-None indices return the broadcasted update | ||||
|   if (non_none_indices == 0) { | ||||
| @@ -703,7 +689,7 @@ mlx_scatter_args_nd( | ||||
|     } else if (nb::isinstance<mx::array>(pyidx)) { | ||||
|       ax++; | ||||
|       auto idx = nb::cast<mx::array>(pyidx); | ||||
|       std::vector<int> idx_shape(idx_ndim, 1); | ||||
|       mx::Shape idx_shape(idx_ndim, 1); | ||||
|  | ||||
|       // Place the arrays in the correct dimension | ||||
|       int st = (!arrays_first) * slice_num + max_dim - idx.ndim(); | ||||
| @@ -801,17 +787,18 @@ auto mlx_slice_update( | ||||
|  | ||||
|   // Remove extra leading singletons dimensions from the update | ||||
|   int s = 0; | ||||
|   for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim(); | ||||
|   for (; s < static_cast<int>(upd.ndim()) - 1 && upd.shape(s) == 1 && | ||||
|        (upd.ndim() - s) > src.ndim(); | ||||
|        s++) { | ||||
|   }; | ||||
|   auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end()); | ||||
|   up_shape = up_shape.empty() ? std::vector{1} : up_shape; | ||||
|   auto up = reshape(upd, up_shape); | ||||
|   auto squeeze_axes = std::vector<int>(s); | ||||
|   std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0); | ||||
|   auto up = mx::squeeze(upd, squeeze_axes); | ||||
|  | ||||
|   // Build slice update params | ||||
|   std::vector<int> starts(src.ndim(), 0); | ||||
|   std::vector<int> stops = src.shape(); | ||||
|   std::vector<int> strides(src.ndim(), 1); | ||||
|   mx::Shape starts(src.ndim(), 0); | ||||
|   mx::Shape stops = src.shape(); | ||||
|   mx::Shape strides(src.ndim(), 1); | ||||
|  | ||||
|   // If it's just a simple slice, just do a slice update and return | ||||
|   if (nb::isinstance<nb::slice>(obj)) { | ||||
| @@ -847,7 +834,7 @@ auto mlx_slice_update( | ||||
|   } | ||||
|  | ||||
|   // Process entries | ||||
|   std::vector<int> up_reshape(src.ndim()); | ||||
|   mx::Shape up_reshape(src.ndim()); | ||||
|   int ax = src.ndim() - 1; | ||||
|   int up_ax = up.ndim() - 1; | ||||
|   for (; ax >= non_none_indices; ax--) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun