mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
@@ -196,6 +196,7 @@ auto mlx_expand_ellipsis(
|
||||
int non_none_indices_after = 0;
|
||||
std::vector<nb::object> r_indices;
|
||||
int i = 0;
|
||||
bool has_ellipsis = false;
|
||||
|
||||
// Start from dimension 0 till we hit an ellipsis
|
||||
for (; i < entries.size(); i++) {
|
||||
@@ -208,6 +209,7 @@ auto mlx_expand_ellipsis(
|
||||
indices.push_back(idx);
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
has_ellipsis = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -231,11 +233,13 @@ auto mlx_expand_ellipsis(
|
||||
int non_none_indices = non_none_indices_before + non_none_indices_after;
|
||||
|
||||
// Expand ellipsis
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < shape.size() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(nb::slice(0, shape[axis], 1));
|
||||
non_none_indices++;
|
||||
if (has_ellipsis) {
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < shape.size() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(nb::slice(0, shape[axis], 1));
|
||||
non_none_indices++;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert indices collected after the ellipsis
|
||||
@@ -409,6 +413,10 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
out_shape.push_back(src.shape(axis++));
|
||||
}
|
||||
}
|
||||
|
||||
out_shape.insert(
|
||||
out_shape.end(), src.shape().begin() + axis, src.shape().end());
|
||||
|
||||
src = reshape(src, out_shape);
|
||||
}
|
||||
|
||||
@@ -580,6 +588,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
int num_slices = 0;
|
||||
int num_arrays = 0;
|
||||
int num_strided_slices = 0;
|
||||
int num_simple_slices_post = 0;
|
||||
{
|
||||
bool have_array = false;
|
||||
bool have_non_array = false;
|
||||
@@ -595,6 +604,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
auto slice = nb::cast<nb::slice>(idx);
|
||||
int stride = get_slice_int(nb::getattr(slice, "step"), 1);
|
||||
num_strided_slices += (stride != 1);
|
||||
num_simple_slices_post += (stride == 1);
|
||||
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
have_array = true;
|
||||
@@ -603,16 +613,17 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
}
|
||||
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
|
||||
num_arrays++;
|
||||
num_simple_slices_post = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We have index dims for the arrays, strided slices (implemented as arrays),
|
||||
// none
|
||||
int idx_ndim = max_dim + num_strided_slices + num_none;
|
||||
int idx_ndim = max_dim + num_none + num_slices - num_simple_slices_post;
|
||||
|
||||
// If we have simple non-strided slices, we also attach an index for that
|
||||
idx_ndim += (num_slices < num_strided_slices);
|
||||
idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
|
||||
|
||||
// Go over each index type and translate to the needed scatter args
|
||||
std::vector<array> arr_indices;
|
||||
@@ -639,22 +650,27 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
std::vector<int> idx_shape(idx_ndim, 1);
|
||||
|
||||
// If it's a simple slice, we only need to add the start index
|
||||
if (stride == 1) {
|
||||
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
|
||||
auto idx = array({start}, idx_shape, uint32);
|
||||
slice_shapes.push_back(end - start);
|
||||
arr_indices.push_back(idx);
|
||||
|
||||
// Add the shape to the update
|
||||
update_shape[ax - 1] = slice_shapes.back();
|
||||
}
|
||||
// Otherwise we expand the slice into indices using arange
|
||||
else {
|
||||
auto idx = arange(start, end, stride, uint32);
|
||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||
slice_num++;
|
||||
idx_shape[loc] = idx.size();
|
||||
slice_shapes.push_back(idx.size());
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
|
||||
slice_num++;
|
||||
num_strided_slices--;
|
||||
|
||||
// Add the shape to the update
|
||||
update_shape[ax - 1] = 1;
|
||||
}
|
||||
// Add the shape to the update
|
||||
update_shape[ax - 1] = slice_shapes.back();
|
||||
} else if (nb::isinstance<nb::int_>(pyidx)) {
|
||||
// Add index to arrays
|
||||
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
|
||||
|
Reference in New Issue
Block a user