mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 20:41:13 +08:00
setite negative indexing bug (#189)
This commit is contained in:
parent
dc2edc762c
commit
104c34f906
@ -41,9 +41,6 @@ void get_slice_params(
|
||||
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
|
||||
ends = get_slice_int(
|
||||
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||
|
||||
// starts = (starts < 0) ? starts + axis_size : starts;
|
||||
// ends = (ends < 0) ? ends + axis_size : ends;
|
||||
}
|
||||
|
||||
array get_int_index(py::object idx, int axis_size) {
|
||||
@ -568,7 +565,13 @@ array mlx_set_item_nd(
|
||||
auto& pyidx = indices[i];
|
||||
if (py::isinstance<py::slice>(pyidx)) {
|
||||
int start, end, stride;
|
||||
get_slice_params(start, end, stride, pyidx, src.shape(ax++));
|
||||
auto axis_size = src.shape(ax++);
|
||||
get_slice_params(start, end, stride, pyidx, axis_size);
|
||||
|
||||
// Handle negative indices
|
||||
start = (start < 0) ? start + axis_size : start;
|
||||
end = (end < 0) ? end + axis_size : end;
|
||||
|
||||
auto idx = arange(start, end, stride, uint32);
|
||||
std::vector<int> idx_shape(max_dim + num_slices, 1);
|
||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||
|
@ -903,6 +903,11 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
np.array([0, 1]),
|
||||
)
|
||||
|
||||
# Check slice assign with negative indices works
|
||||
a = mx.zeros((5, 5), mx.int32)
|
||||
a[2:-2, 2:-2] = 4
|
||||
self.assertEquals(a[2, 2].item(), 4)
|
||||
|
||||
def test_slice_negative_step(self):
|
||||
a_np = np.arange(20)
|
||||
a_mx = mx.array(a_np)
|
||||
|
Loading…
Reference in New Issue
Block a user