Set item bug fix (#879)

* Set item shaping bug fix

* Add extra tests
This commit is contained in:
Jagrit Digani 2024-03-22 12:11:17 -07:00 committed by GitHub
parent fcda3a0e66
commit 8e5a5a1ccd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 12 deletions

View File

@ -196,6 +196,7 @@ auto mlx_expand_ellipsis(
int non_none_indices_after = 0; int non_none_indices_after = 0;
std::vector<nb::object> r_indices; std::vector<nb::object> r_indices;
int i = 0; int i = 0;
bool has_ellipsis = false;
// Start from dimension 0 till we hit an ellipsis // Start from dimension 0 till we hit an ellipsis
for (; i < entries.size(); i++) { for (; i < entries.size(); i++) {
@ -208,6 +209,7 @@ auto mlx_expand_ellipsis(
indices.push_back(idx); indices.push_back(idx);
non_none_indices_before += !idx.is_none(); non_none_indices_before += !idx.is_none();
} else { } else {
has_ellipsis = true;
break; break;
} }
} }
@ -231,12 +233,14 @@ auto mlx_expand_ellipsis(
int non_none_indices = non_none_indices_before + non_none_indices_after; int non_none_indices = non_none_indices_before + non_none_indices_after;
// Expand ellipsis // Expand ellipsis
if (has_ellipsis) {
for (int axis = non_none_indices_before; for (int axis = non_none_indices_before;
axis < shape.size() - non_none_indices_after; axis < shape.size() - non_none_indices_after;
axis++) { axis++) {
indices.push_back(nb::slice(0, shape[axis], 1)); indices.push_back(nb::slice(0, shape[axis], 1));
non_none_indices++; non_none_indices++;
} }
}
// Insert indices collected after the ellipsis // Insert indices collected after the ellipsis
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend()); indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
@ -409,6 +413,10 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
out_shape.push_back(src.shape(axis++)); out_shape.push_back(src.shape(axis++));
} }
} }
out_shape.insert(
out_shape.end(), src.shape().begin() + axis, src.shape().end());
src = reshape(src, out_shape); src = reshape(src, out_shape);
} }
@ -580,6 +588,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
int num_slices = 0; int num_slices = 0;
int num_arrays = 0; int num_arrays = 0;
int num_strided_slices = 0; int num_strided_slices = 0;
int num_simple_slices_post = 0;
{ {
bool have_array = false; bool have_array = false;
bool have_non_array = false; bool have_non_array = false;
@ -595,6 +604,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
auto slice = nb::cast<nb::slice>(idx); auto slice = nb::cast<nb::slice>(idx);
int stride = get_slice_int(nb::getattr(slice, "step"), 1); int stride = get_slice_int(nb::getattr(slice, "step"), 1);
num_strided_slices += (stride != 1); num_strided_slices += (stride != 1);
num_simple_slices_post += (stride == 1);
} else if (nb::isinstance<array>(idx)) { } else if (nb::isinstance<array>(idx)) {
have_array = true; have_array = true;
@ -603,16 +613,17 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
} }
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim); max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
num_arrays++; num_arrays++;
num_simple_slices_post = 0;
} }
} }
} }
// We have index dims for the arrays, strided slices (implemented as arrays), // We have index dims for the arrays, strided slices (implemented as arrays),
// none // none
int idx_ndim = max_dim + num_strided_slices + num_none; int idx_ndim = max_dim + num_none + num_slices - num_simple_slices_post;
// If we have simple non-strided slices, we also attach an index for that // If we have simple non-strided slices, we also attach an index for that
idx_ndim += (num_slices < num_strided_slices); idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
// Go over each index type and translate to the needed scatter args // Go over each index type and translate to the needed scatter args
std::vector<array> arr_indices; std::vector<array> arr_indices;
@ -639,22 +650,27 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
std::vector<int> idx_shape(idx_ndim, 1); std::vector<int> idx_shape(idx_ndim, 1);
// If it's a simple slice, we only need to add the start index // If it's a simple slice, we only need to add the start index
if (stride == 1) { if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
auto idx = array({start}, idx_shape, uint32); auto idx = array({start}, idx_shape, uint32);
slice_shapes.push_back(end - start); slice_shapes.push_back(end - start);
arr_indices.push_back(idx); arr_indices.push_back(idx);
// Add the shape to the update
update_shape[ax - 1] = slice_shapes.back();
} }
// Otherwise we expand the slice into indices using arange // Otherwise we expand the slice into indices using arange
else { else {
auto idx = arange(start, end, stride, uint32); auto idx = arange(start, end, stride, uint32);
auto loc = slice_num + (arrays_first ? max_dim : 0); auto loc = slice_num + (arrays_first ? max_dim : 0);
slice_num++;
idx_shape[loc] = idx.size(); idx_shape[loc] = idx.size();
slice_shapes.push_back(idx.size());
arr_indices.push_back(reshape(idx, idx_shape)); arr_indices.push_back(reshape(idx, idx_shape));
}
slice_num++;
num_strided_slices--;
// Add the shape to the update // Add the shape to the update
update_shape[ax - 1] = slice_shapes.back(); update_shape[ax - 1] = 1;
}
} else if (nb::isinstance<nb::int_>(pyidx)) { } else if (nb::isinstance<nb::int_>(pyidx)) {
// Add index to arrays // Add index to arrays
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));

View File

@ -1058,6 +1058,29 @@ class TestArray(mlx_tests.MLXTestCase):
a[2:-2, 2:-2] = 4 a[2:-2, 2:-2] = 4
self.assertEqual(a[2, 2].item(), 4) self.assertEqual(a[2, 2].item(), 4)
# Check slice array slice
check_slices(
np.zeros((5, 4, 4)),
np.arange(4 * 2 * 3).reshape(4, 2, 3),
slice(0, 4),
np.array([1, 3]),
slice(None, -1),
)
check_slices(
np.zeros((5, 4, 4)),
np.arange(4 * 2 * 2).reshape(4, 2, 2),
slice(0, 4),
np.array([1, 3]),
slice(0, 4, 2),
)
check_slices(
np.zeros((1, 10, 4)),
np.arange(2 * 4).reshape(1, 2, 4),
slice(None, None, None),
np.array([1, 3]),
)
def test_array_at(self): def test_array_at(self):
a = mx.array(1) a = mx.array(1)
a = a.at[None].add(1) a = a.at[None].add(1)