mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	@@ -196,6 +196,7 @@ auto mlx_expand_ellipsis(
 | 
			
		||||
  int non_none_indices_after = 0;
 | 
			
		||||
  std::vector<nb::object> r_indices;
 | 
			
		||||
  int i = 0;
 | 
			
		||||
  bool has_ellipsis = false;
 | 
			
		||||
 | 
			
		||||
  // Start from dimension 0 till we hit an ellipsis
 | 
			
		||||
  for (; i < entries.size(); i++) {
 | 
			
		||||
@@ -208,6 +209,7 @@ auto mlx_expand_ellipsis(
 | 
			
		||||
      indices.push_back(idx);
 | 
			
		||||
      non_none_indices_before += !idx.is_none();
 | 
			
		||||
    } else {
 | 
			
		||||
      has_ellipsis = true;
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
@@ -231,11 +233,13 @@ auto mlx_expand_ellipsis(
 | 
			
		||||
  int non_none_indices = non_none_indices_before + non_none_indices_after;
 | 
			
		||||
 | 
			
		||||
  // Expand ellipsis
 | 
			
		||||
  for (int axis = non_none_indices_before;
 | 
			
		||||
       axis < shape.size() - non_none_indices_after;
 | 
			
		||||
       axis++) {
 | 
			
		||||
    indices.push_back(nb::slice(0, shape[axis], 1));
 | 
			
		||||
    non_none_indices++;
 | 
			
		||||
  if (has_ellipsis) {
 | 
			
		||||
    for (int axis = non_none_indices_before;
 | 
			
		||||
         axis < shape.size() - non_none_indices_after;
 | 
			
		||||
         axis++) {
 | 
			
		||||
      indices.push_back(nb::slice(0, shape[axis], 1));
 | 
			
		||||
      non_none_indices++;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Insert indices collected after the ellipsis
 | 
			
		||||
@@ -409,6 +413,10 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
 | 
			
		||||
        out_shape.push_back(src.shape(axis++));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    out_shape.insert(
 | 
			
		||||
        out_shape.end(), src.shape().begin() + axis, src.shape().end());
 | 
			
		||||
 | 
			
		||||
    src = reshape(src, out_shape);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -580,6 +588,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
 | 
			
		||||
  int num_slices = 0;
 | 
			
		||||
  int num_arrays = 0;
 | 
			
		||||
  int num_strided_slices = 0;
 | 
			
		||||
  int num_simple_slices_post = 0;
 | 
			
		||||
  {
 | 
			
		||||
    bool have_array = false;
 | 
			
		||||
    bool have_non_array = false;
 | 
			
		||||
@@ -595,6 +604,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
 | 
			
		||||
        auto slice = nb::cast<nb::slice>(idx);
 | 
			
		||||
        int stride = get_slice_int(nb::getattr(slice, "step"), 1);
 | 
			
		||||
        num_strided_slices += (stride != 1);
 | 
			
		||||
        num_simple_slices_post += (stride == 1);
 | 
			
		||||
 | 
			
		||||
      } else if (nb::isinstance<array>(idx)) {
 | 
			
		||||
        have_array = true;
 | 
			
		||||
@@ -603,16 +613,17 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
 | 
			
		||||
        }
 | 
			
		||||
        max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
 | 
			
		||||
        num_arrays++;
 | 
			
		||||
        num_simple_slices_post = 0;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // We have index dims for the arrays, strided slices (implemented as arrays),
 | 
			
		||||
  // none
 | 
			
		||||
  int idx_ndim = max_dim + num_strided_slices + num_none;
 | 
			
		||||
  int idx_ndim = max_dim + num_none + num_slices - num_simple_slices_post;
 | 
			
		||||
 | 
			
		||||
  // If we have simple non-strided slices, we also attach an index for that
 | 
			
		||||
  idx_ndim += (num_slices < num_strided_slices);
 | 
			
		||||
  idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
 | 
			
		||||
 | 
			
		||||
  // Go over each index type and translate to the needed scatter args
 | 
			
		||||
  std::vector<array> arr_indices;
 | 
			
		||||
@@ -639,22 +650,27 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
 | 
			
		||||
      std::vector<int> idx_shape(idx_ndim, 1);
 | 
			
		||||
 | 
			
		||||
      // If it's a simple slice, we only need to add the start index
 | 
			
		||||
      if (stride == 1) {
 | 
			
		||||
      if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
 | 
			
		||||
        auto idx = array({start}, idx_shape, uint32);
 | 
			
		||||
        slice_shapes.push_back(end - start);
 | 
			
		||||
        arr_indices.push_back(idx);
 | 
			
		||||
 | 
			
		||||
        // Add the shape to the update
 | 
			
		||||
        update_shape[ax - 1] = slice_shapes.back();
 | 
			
		||||
      }
 | 
			
		||||
      // Otherwise we expand the slice into indices using arange
 | 
			
		||||
      else {
 | 
			
		||||
        auto idx = arange(start, end, stride, uint32);
 | 
			
		||||
        auto loc = slice_num + (arrays_first ? max_dim : 0);
 | 
			
		||||
        slice_num++;
 | 
			
		||||
        idx_shape[loc] = idx.size();
 | 
			
		||||
        slice_shapes.push_back(idx.size());
 | 
			
		||||
        arr_indices.push_back(reshape(idx, idx_shape));
 | 
			
		||||
 | 
			
		||||
        slice_num++;
 | 
			
		||||
        num_strided_slices--;
 | 
			
		||||
 | 
			
		||||
        // Add the shape to the update
 | 
			
		||||
        update_shape[ax - 1] = 1;
 | 
			
		||||
      }
 | 
			
		||||
      // Add the shape to the update
 | 
			
		||||
      update_shape[ax - 1] = slice_shapes.back();
 | 
			
		||||
    } else if (nb::isinstance<nb::int_>(pyidx)) {
 | 
			
		||||
      // Add index to arrays
 | 
			
		||||
      arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user