diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index b60b23cfa..446f17930 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -606,8 +606,12 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( auto slice = nb::cast(idx); int stride = get_slice_int(nb::getattr(slice, "step"), 1); - num_strided_slices += (stride != 1); - num_simple_slices_post += (stride == 1); + if (stride != 1) { + num_strided_slices++; + num_simple_slices_post = 0; + } else { + num_simple_slices_post++; + } } else if (nb::isinstance(idx)) { have_array = true; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 587a98e2e..8f95edf2b 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1210,6 +1210,23 @@ class TestArray(mlx_tests.MLXTestCase): np.array([1, 3]), ) + check_slices( + np.zeros((3, 4, 5, 3)), + np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3), + np.array([2, 1]), + slice(None, None, None), + slice(None, None, 2), + slice(None, None, None), + ) + + check_slices( + np.zeros((3, 4, 5, 3)), + np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3), + np.array([2, 1]), + slice(None, None, None), + slice(None, None, 2), + ) + def test_array_at(self): a = mx.array(1) a = a.at[None].add(1)