mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
@@ -120,6 +120,11 @@ array mlx_gather_nd(
|
||||
if (py::isinstance<py::slice>(idx)) {
|
||||
int start, end, stride;
|
||||
get_slice_params(start, end, stride, idx, src.shape(i));
|
||||
|
||||
// Handle negative indices
|
||||
start = (start < 0) ? start + src.shape(i) : start;
|
||||
end = (end < 0) ? end + src.shape(i) : end;
|
||||
|
||||
gather_indices.push_back(arange(start, end, stride, uint32));
|
||||
num_slices++;
|
||||
is_slice[i] = true;
|
||||
|
Reference in New Issue
Block a user