diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp index c446ff948..93b9d480e 100644 --- a/mlx/backend/common/slicing.cpp +++ b/mlx/backend/common/slicing.cpp @@ -14,6 +14,10 @@ std::tuple prepare_slice( data_offset += start_indices[i] * in.strides()[i]; inp_strides[i] = in.strides()[i] * strides[i]; } + // Normalize the offset + if (data_offset < 0) { + data_offset += in.data_size(); + } return std::make_tuple(data_offset, inp_strides); } @@ -54,9 +58,10 @@ void slice( data_end += end_idx * in.strides()[i]; } } - // data_end can be -1 - size_t data_size = - data_end < 0 ? (data_offset - data_end) : (data_end - data_offset); + if (data_end < 0) { + data_end += in.data_size(); + } + size_t data_size = (data_end - data_offset); shared_buffer_slice(in, inp_strides, data_offset, data_size, out); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index f5078afc0..515836a5e 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2846,6 +2846,11 @@ class TestOps(mlx_tests.MLXTestCase): b[::2] = 0 self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1]))) + def test_slice_with_negative_stride(self): + a = mx.random.uniform(shape=(128, 4)) + out = a[::-1] + self.assertTrue(mx.array_equal(out[-1, :], a[0, :])) + if __name__ == "__main__": unittest.main()