Indexing bug (#233)

* fix

* test
This commit is contained in:
Awni Hannun
2023-12-20 10:44:01 -08:00
committed by GitHub
parent 2807c6aff0
commit f40d17047d
2 changed files with 10 additions and 0 deletions

View File

@@ -120,6 +120,11 @@ array mlx_gather_nd(
if (py::isinstance<py::slice>(idx)) {
int start, end, stride;
get_slice_params(start, end, stride, idx, src.shape(i));
// Handle negative indices
start = (start < 0) ? start + src.shape(i) : start;
end = (end < 0) ? end + src.shape(i) : end;
gather_indices.push_back(arange(start, end, stride, uint32));
num_slices++;
is_slice[i] = true;