From f40d17047de6efbecd60b0c92835a743fe209ba3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 20 Dec 2023 10:44:01 -0800 Subject: [PATCH] Indexing bug (#233) * fix * test --- python/src/indexing.cpp | 5 +++++ python/tests/test_array.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index d66c9a0e8..1454c5180 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -120,6 +120,11 @@ array mlx_gather_nd( if (py::isinstance(idx)) { int start, end, stride; get_slice_params(start, end, stride, idx, src.shape(i)); + + // Handle negative indices + start = (start < 0) ? start + src.shape(i) : start; + end = (end < 0) ? end + src.shape(i) : end; + gather_indices.push_back(arange(start, end, stride, uint32)); num_slices++; is_slice[i] = true; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 85f1aa257..fb6a24cbc 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -727,6 +727,11 @@ class TestArray(mlx_tests.MLXTestCase): np.array_equal(a_np[idx_np, idx_np], np.array(a_mlx[idx_mlx, idx_mlx])) ) + # Slicing with negative indices and integer + a_np = np.arange(10).reshape(5, 2) + a_mlx = mx.array(a_np) + self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0]))) + def test_setitem(self): a = mx.array(0) a[None] = 1