mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
setite negative indexing bug (#189)
This commit is contained in:
parent
dc2edc762c
commit
104c34f906
@ -41,9 +41,6 @@ void get_slice_params(
|
|||||||
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
|
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
|
||||||
ends = get_slice_int(
|
ends = get_slice_int(
|
||||||
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||||
|
|
||||||
// starts = (starts < 0) ? starts + axis_size : starts;
|
|
||||||
// ends = (ends < 0) ? ends + axis_size : ends;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array get_int_index(py::object idx, int axis_size) {
|
array get_int_index(py::object idx, int axis_size) {
|
||||||
@ -568,7 +565,13 @@ array mlx_set_item_nd(
|
|||||||
auto& pyidx = indices[i];
|
auto& pyidx = indices[i];
|
||||||
if (py::isinstance<py::slice>(pyidx)) {
|
if (py::isinstance<py::slice>(pyidx)) {
|
||||||
int start, end, stride;
|
int start, end, stride;
|
||||||
get_slice_params(start, end, stride, pyidx, src.shape(ax++));
|
auto axis_size = src.shape(ax++);
|
||||||
|
get_slice_params(start, end, stride, pyidx, axis_size);
|
||||||
|
|
||||||
|
// Handle negative indices
|
||||||
|
start = (start < 0) ? start + axis_size : start;
|
||||||
|
end = (end < 0) ? end + axis_size : end;
|
||||||
|
|
||||||
auto idx = arange(start, end, stride, uint32);
|
auto idx = arange(start, end, stride, uint32);
|
||||||
std::vector<int> idx_shape(max_dim + num_slices, 1);
|
std::vector<int> idx_shape(max_dim + num_slices, 1);
|
||||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||||
|
@ -903,6 +903,11 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
np.array([0, 1]),
|
np.array([0, 1]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check slice assign with negative indices works
|
||||||
|
a = mx.zeros((5, 5), mx.int32)
|
||||||
|
a[2:-2, 2:-2] = 4
|
||||||
|
self.assertEquals(a[2, 2].item(), 4)
|
||||||
|
|
||||||
def test_slice_negative_step(self):
|
def test_slice_negative_step(self):
|
||||||
a_np = np.arange(20)
|
a_np = np.arange(20)
|
||||||
a_mx = mx.array(a_np)
|
a_mx = mx.array(a_np)
|
||||||
|
Loading…
Reference in New Issue
Block a user