Set item bug fix (#879)

* set item shaping bug fix

* Add extra tests
This commit is contained in:
Jagrit Digani 2024-03-22 12:11:17 -07:00 committed by GitHub
parent fcda3a0e66
commit 8e5a5a1ccd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 12 deletions

View File

@ -196,6 +196,7 @@ auto mlx_expand_ellipsis(
int non_none_indices_after = 0;
std::vector<nb::object> r_indices;
int i = 0;
bool has_ellipsis = false;
// Start from dimension 0 till we hit an ellipsis
for (; i < entries.size(); i++) {
@ -208,6 +209,7 @@ auto mlx_expand_ellipsis(
indices.push_back(idx);
non_none_indices_before += !idx.is_none();
} else {
has_ellipsis = true;
break;
}
}
@ -231,11 +233,13 @@ auto mlx_expand_ellipsis(
int non_none_indices = non_none_indices_before + non_none_indices_after;
// Expand ellipsis
for (int axis = non_none_indices_before;
axis < shape.size() - non_none_indices_after;
axis++) {
indices.push_back(nb::slice(0, shape[axis], 1));
non_none_indices++;
if (has_ellipsis) {
for (int axis = non_none_indices_before;
axis < shape.size() - non_none_indices_after;
axis++) {
indices.push_back(nb::slice(0, shape[axis], 1));
non_none_indices++;
}
}
// Insert indices collected after the ellipsis
@ -409,6 +413,10 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
out_shape.push_back(src.shape(axis++));
}
}
out_shape.insert(
out_shape.end(), src.shape().begin() + axis, src.shape().end());
src = reshape(src, out_shape);
}
@ -580,6 +588,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
int num_slices = 0;
int num_arrays = 0;
int num_strided_slices = 0;
int num_simple_slices_post = 0;
{
bool have_array = false;
bool have_non_array = false;
@ -595,6 +604,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
auto slice = nb::cast<nb::slice>(idx);
int stride = get_slice_int(nb::getattr(slice, "step"), 1);
num_strided_slices += (stride != 1);
num_simple_slices_post += (stride == 1);
} else if (nb::isinstance<array>(idx)) {
have_array = true;
@ -603,16 +613,17 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
}
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
num_arrays++;
num_simple_slices_post = 0;
}
}
}
// We have index dims for the arrays, strided slices (implemented as arrays),
// none
int idx_ndim = max_dim + num_strided_slices + num_none;
int idx_ndim = max_dim + num_none + num_slices - num_simple_slices_post;
// If we have simple non-strided slices, we also attach an index for that
idx_ndim += (num_slices < num_strided_slices);
idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
// Go over each index type and translate to the needed scatter args
std::vector<array> arr_indices;
@ -639,22 +650,27 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
std::vector<int> idx_shape(idx_ndim, 1);
// If it's a simple slice, we only need to add the start index
if (stride == 1) {
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
auto idx = array({start}, idx_shape, uint32);
slice_shapes.push_back(end - start);
arr_indices.push_back(idx);
// Add the shape to the update
update_shape[ax - 1] = slice_shapes.back();
}
// Otherwise we expand the slice into indices using arange
else {
auto idx = arange(start, end, stride, uint32);
auto loc = slice_num + (arrays_first ? max_dim : 0);
slice_num++;
idx_shape[loc] = idx.size();
slice_shapes.push_back(idx.size());
arr_indices.push_back(reshape(idx, idx_shape));
slice_num++;
num_strided_slices--;
// Add the shape to the update
update_shape[ax - 1] = 1;
}
// Add the shape to the update
update_shape[ax - 1] = slice_shapes.back();
} else if (nb::isinstance<nb::int_>(pyidx)) {
// Add index to arrays
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));

View File

@ -1058,6 +1058,29 @@ class TestArray(mlx_tests.MLXTestCase):
a[2:-2, 2:-2] = 4
self.assertEqual(a[2, 2].item(), 4)
# Check slice array slice
check_slices(
np.zeros((5, 4, 4)),
np.arange(4 * 2 * 3).reshape(4, 2, 3),
slice(0, 4),
np.array([1, 3]),
slice(None, -1),
)
check_slices(
np.zeros((5, 4, 4)),
np.arange(4 * 2 * 2).reshape(4, 2, 2),
slice(0, 4),
np.array([1, 3]),
slice(0, 4, 2),
)
check_slices(
np.zeros((1, 10, 4)),
np.arange(2 * 4).reshape(1, 2, 4),
slice(None, None, None),
np.array([1, 3]),
)
def test_array_at(self):
a = mx.array(1)
a = a.at[None].add(1)