Indexing bug fix (#947)

* Fix axes accounting

* Add tests
This commit is contained in:
Jagrit Digani 2024-04-01 12:18:50 -07:00 committed by GitHub
parent 02fedbf1da
commit 639e06e1f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 2 deletions

View File

@ -606,8 +606,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
auto slice = nb::cast<nb::slice>(idx);
int stride = get_slice_int(nb::getattr(slice, "step"), 1);
num_strided_slices += (stride != 1);
num_simple_slices_post += (stride == 1);
if (stride != 1) {
num_strided_slices++;
num_simple_slices_post = 0;
} else {
num_simple_slices_post++;
}
} else if (nb::isinstance<array>(idx)) {
have_array = true;

View File

@ -1210,6 +1210,23 @@ class TestArray(mlx_tests.MLXTestCase):
np.array([1, 3]),
)
check_slices(
np.zeros((3, 4, 5, 3)),
np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
np.array([2, 1]),
slice(None, None, None),
slice(None, None, 2),
slice(None, None, None),
)
check_slices(
np.zeros((3, 4, 5, 3)),
np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
np.array([2, 1]),
slice(None, None, None),
slice(None, None, 2),
)
def test_array_at(self):
a = mx.array(1)
a = a.at[None].add(1)