setite negative indexing bug (#189)

This commit is contained in:
Awni Hannun 2023-12-16 06:44:47 -08:00 committed by GitHub
parent dc2edc762c
commit 104c34f906
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 4 deletions

View File

@ -41,9 +41,6 @@ void get_slice_params(
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
ends = get_slice_int(
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
// starts = (starts < 0) ? starts + axis_size : starts;
// ends = (ends < 0) ? ends + axis_size : ends;
}
array get_int_index(py::object idx, int axis_size) {
@ -568,7 +565,13 @@ array mlx_set_item_nd(
auto& pyidx = indices[i];
if (py::isinstance<py::slice>(pyidx)) {
int start, end, stride;
get_slice_params(start, end, stride, pyidx, src.shape(ax++));
auto axis_size = src.shape(ax++);
get_slice_params(start, end, stride, pyidx, axis_size);
// Handle negative indices
start = (start < 0) ? start + axis_size : start;
end = (end < 0) ? end + axis_size : end;
auto idx = arange(start, end, stride, uint32);
std::vector<int> idx_shape(max_dim + num_slices, 1);
auto loc = slice_num + (arrays_first ? max_dim : 0);

View File

@ -903,6 +903,11 @@ class TestArray(mlx_tests.MLXTestCase):
np.array([0, 1]),
)
# Check slice assign with negative indices works
a = mx.zeros((5, 5), mx.int32)
a[2:-2, 2:-2] = 4
self.assertEquals(a[2, 2].item(), 4)
def test_slice_negative_step(self):
a_np = np.arange(20)
a_mx = mx.array(a_np)