From 104c34f906cbff529df259973b8ccf9ded5efdf6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 16 Dec 2023 06:44:47 -0800 Subject: [PATCH] setite negative indexing bug (#189) --- python/src/indexing.cpp | 11 +++++++---- python/tests/test_array.py | 5 +++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 788e2fd5f..d66c9a0e8 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -41,9 +41,6 @@ void get_slice_params( py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0); ends = get_slice_int( py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size); - - // starts = (starts < 0) ? starts + axis_size : starts; - // ends = (ends < 0) ? ends + axis_size : ends; } array get_int_index(py::object idx, int axis_size) { @@ -568,7 +565,13 @@ array mlx_set_item_nd( auto& pyidx = indices[i]; if (py::isinstance(pyidx)) { int start, end, stride; - get_slice_params(start, end, stride, pyidx, src.shape(ax++)); + auto axis_size = src.shape(ax++); + get_slice_params(start, end, stride, pyidx, axis_size); + + // Handle negative indices + start = (start < 0) ? start + axis_size : start; + end = (end < 0) ? end + axis_size : end; + auto idx = arange(start, end, stride, uint32); std::vector idx_shape(max_dim + num_slices, 1); auto loc = slice_num + (arrays_first ? max_dim : 0); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index addec3493..cabc6a114 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -903,6 +903,11 @@ class TestArray(mlx_tests.MLXTestCase): np.array([0, 1]), ) + # Check slice assign with negative indices works + a = mx.zeros((5, 5), mx.int32) + a[2:-2, 2:-2] = 4 + self.assertEquals(a[2, 2].item(), 4) + def test_slice_negative_step(self): a_np = np.arange(20) a_mx = mx.array(a_np)