mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 10:27:41 +08:00
parent
fcda3a0e66
commit
8e5a5a1ccd
@ -196,6 +196,7 @@ auto mlx_expand_ellipsis(
|
||||
int non_none_indices_after = 0;
|
||||
std::vector<nb::object> r_indices;
|
||||
int i = 0;
|
||||
bool has_ellipsis = false;
|
||||
|
||||
// Start from dimension 0 till we hit an ellipsis
|
||||
for (; i < entries.size(); i++) {
|
||||
@ -208,6 +209,7 @@ auto mlx_expand_ellipsis(
|
||||
indices.push_back(idx);
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
has_ellipsis = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -231,11 +233,13 @@ auto mlx_expand_ellipsis(
|
||||
int non_none_indices = non_none_indices_before + non_none_indices_after;
|
||||
|
||||
// Expand ellipsis
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < shape.size() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(nb::slice(0, shape[axis], 1));
|
||||
non_none_indices++;
|
||||
if (has_ellipsis) {
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < shape.size() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(nb::slice(0, shape[axis], 1));
|
||||
non_none_indices++;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert indices collected after the ellipsis
|
||||
@ -409,6 +413,10 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
out_shape.push_back(src.shape(axis++));
|
||||
}
|
||||
}
|
||||
|
||||
out_shape.insert(
|
||||
out_shape.end(), src.shape().begin() + axis, src.shape().end());
|
||||
|
||||
src = reshape(src, out_shape);
|
||||
}
|
||||
|
||||
@ -580,6 +588,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
int num_slices = 0;
|
||||
int num_arrays = 0;
|
||||
int num_strided_slices = 0;
|
||||
int num_simple_slices_post = 0;
|
||||
{
|
||||
bool have_array = false;
|
||||
bool have_non_array = false;
|
||||
@ -595,6 +604,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
auto slice = nb::cast<nb::slice>(idx);
|
||||
int stride = get_slice_int(nb::getattr(slice, "step"), 1);
|
||||
num_strided_slices += (stride != 1);
|
||||
num_simple_slices_post += (stride == 1);
|
||||
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
have_array = true;
|
||||
@ -603,16 +613,17 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
}
|
||||
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
|
||||
num_arrays++;
|
||||
num_simple_slices_post = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We have index dims for the arrays, strided slices (implemented as arrays),
|
||||
// none
|
||||
int idx_ndim = max_dim + num_strided_slices + num_none;
|
||||
int idx_ndim = max_dim + num_none + num_slices - num_simple_slices_post;
|
||||
|
||||
// If we have simple non-strided slices, we also attach an index for that
|
||||
idx_ndim += (num_slices < num_strided_slices);
|
||||
idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
|
||||
|
||||
// Go over each index type and translate to the needed scatter args
|
||||
std::vector<array> arr_indices;
|
||||
@ -639,22 +650,27 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
std::vector<int> idx_shape(idx_ndim, 1);
|
||||
|
||||
// If it's a simple slice, we only need to add the start index
|
||||
if (stride == 1) {
|
||||
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
|
||||
auto idx = array({start}, idx_shape, uint32);
|
||||
slice_shapes.push_back(end - start);
|
||||
arr_indices.push_back(idx);
|
||||
|
||||
// Add the shape to the update
|
||||
update_shape[ax - 1] = slice_shapes.back();
|
||||
}
|
||||
// Otherwise we expand the slice into indices using arange
|
||||
else {
|
||||
auto idx = arange(start, end, stride, uint32);
|
||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||
slice_num++;
|
||||
idx_shape[loc] = idx.size();
|
||||
slice_shapes.push_back(idx.size());
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
|
||||
slice_num++;
|
||||
num_strided_slices--;
|
||||
|
||||
// Add the shape to the update
|
||||
update_shape[ax - 1] = 1;
|
||||
}
|
||||
// Add the shape to the update
|
||||
update_shape[ax - 1] = slice_shapes.back();
|
||||
} else if (nb::isinstance<nb::int_>(pyidx)) {
|
||||
// Add index to arrays
|
||||
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
|
||||
|
@ -1058,6 +1058,29 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a[2:-2, 2:-2] = 4
|
||||
self.assertEqual(a[2, 2].item(), 4)
|
||||
|
||||
# Check slice array slice
|
||||
check_slices(
|
||||
np.zeros((5, 4, 4)),
|
||||
np.arange(4 * 2 * 3).reshape(4, 2, 3),
|
||||
slice(0, 4),
|
||||
np.array([1, 3]),
|
||||
slice(None, -1),
|
||||
)
|
||||
check_slices(
|
||||
np.zeros((5, 4, 4)),
|
||||
np.arange(4 * 2 * 2).reshape(4, 2, 2),
|
||||
slice(0, 4),
|
||||
np.array([1, 3]),
|
||||
slice(0, 4, 2),
|
||||
)
|
||||
|
||||
check_slices(
|
||||
np.zeros((1, 10, 4)),
|
||||
np.arange(2 * 4).reshape(1, 2, 4),
|
||||
slice(None, None, None),
|
||||
np.array([1, 3]),
|
||||
)
|
||||
|
||||
def test_array_at(self):
|
||||
a = mx.array(1)
|
||||
a = a.at[None].add(1)
|
||||
|
Loading…
Reference in New Issue
Block a user