mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
2807c6aff0
commit
f40d17047d
@ -120,6 +120,11 @@ array mlx_gather_nd(
|
||||
if (py::isinstance<py::slice>(idx)) {
|
||||
int start, end, stride;
|
||||
get_slice_params(start, end, stride, idx, src.shape(i));
|
||||
|
||||
// Handle negative indices
|
||||
start = (start < 0) ? start + src.shape(i) : start;
|
||||
end = (end < 0) ? end + src.shape(i) : end;
|
||||
|
||||
gather_indices.push_back(arange(start, end, stride, uint32));
|
||||
num_slices++;
|
||||
is_slice[i] = true;
|
||||
|
@ -727,6 +727,11 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
np.array_equal(a_np[idx_np, idx_np], np.array(a_mlx[idx_mlx, idx_mlx]))
|
||||
)
|
||||
|
||||
# Slicing with negative indices and integer
|
||||
a_np = np.arange(10).reshape(5, 2)
|
||||
a_mlx = mx.array(a_np)
|
||||
self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))
|
||||
|
||||
def test_setitem(self):
|
||||
a = mx.array(0)
|
||||
a[None] = 1
|
||||
|
Loading…
Reference in New Issue
Block a user