Indexing bug (#233)

* fix

* test
This commit is contained in:
Awni Hannun 2023-12-20 10:44:01 -08:00 committed by GitHub
parent 2807c6aff0
commit f40d17047d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 0 deletions

View File

@ -120,6 +120,11 @@ array mlx_gather_nd(
if (py::isinstance<py::slice>(idx)) {
int start, end, stride;
get_slice_params(start, end, stride, idx, src.shape(i));
// Handle negative indices
start = (start < 0) ? start + src.shape(i) : start;
end = (end < 0) ? end + src.shape(i) : end;
gather_indices.push_back(arange(start, end, stride, uint32));
num_slices++;
is_slice[i] = true;

View File

@ -727,6 +727,11 @@ class TestArray(mlx_tests.MLXTestCase):
np.array_equal(a_np[idx_np, idx_np], np.array(a_mlx[idx_mlx, idx_mlx]))
)
# Slicing with negative indices and integer
a_np = np.arange(10).reshape(5, 2)
a_mlx = mx.array(a_np)
self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))
def test_setitem(self):
a = mx.array(0)
a[None] = 1