fix slice update indexing (#1053)

This commit is contained in:
Awni Hannun
2024-04-29 12:17:40 -07:00
committed by GitHub
parent 490c0c4fdc
commit 09f1777896
4 changed files with 38 additions and 23 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <numeric>
#include <sstream>
@@ -767,6 +766,14 @@ auto mlx_slice_update(
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj))) {
return std::make_pair(false, src);
}
if (nb::isinstance<nb::tuple>(obj)) {
// Can't route to slice update if any arrays are present
for (auto idx : nb::cast<nb::tuple>(obj)) {
if (nb::isinstance<array>(idx)) {
return std::make_pair(false, src);
}
}
}
// Should be able to route to slice update
@@ -804,14 +811,6 @@ auto mlx_slice_update(
// It must be a tuple
auto entries = nb::cast<nb::tuple>(obj);
// Can't route to slice update if any arrays are present
for (int i = 0; i < entries.size(); i++) {
auto idx = entries[i];
if (nb::isinstance<array>(idx)) {
return std::make_pair(false, src);
}
}
// Expand ellipses into a series of ':' slices
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
@@ -828,9 +827,19 @@ auto mlx_slice_update(
}
// Process entries
std::vector<int> upd_expand_dims;
int ax = 0;
for (int i = 0; i < indices.size(); ++i) {
std::vector<int> up_reshape(src.ndim());
int ax = src.ndim() - 1;
int up_ax = up.ndim() - 1;
for (; ax >= non_none_indices; ax--) {
if (up_ax >= 0) {
up_reshape[ax] = up.shape(up_ax);
up_ax--;
} else {
up_reshape[ax] = 1;
}
}
for (int i = indices.size() - 1; i >= 0; --i) {
auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) {
get_slice_params(
@@ -839,18 +848,19 @@ auto mlx_slice_update(
strides[ax],
nb::cast<nb::slice>(pyidx),
src.shape(ax));
ax++;
up_reshape[ax] = (up_ax >= 0) ? up.shape(up_ax--) : 1;
ax--;
} else if (nb::isinstance<nb::int_>(pyidx)) {
int st = nb::cast<int>(pyidx);
st = (st < 0) ? st + src.shape(ax) : st;
starts[ax] = st;
stops[ax] = st + 1;
upd_expand_dims.push_back(ax);
ax++;
up_reshape[ax] = 1;
ax--;
}
}
up = expand_dims(up, upd_expand_dims);
up = reshape(up, std::move(up_reshape));
auto out = slice_update(src, up, starts, stops, strides);
return std::make_pair(true, out);
}

View File

@@ -1262,6 +1262,14 @@ class TestArray(mlx_tests.MLXTestCase):
np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4)
)
x = mx.zeros((2, 3, 4, 5, 3))
x[..., 0] = 1.0
self.assertTrue(mx.array_equal(x[..., 0], mx.ones((2, 3, 4, 5))))
x = mx.zeros((2, 3, 4, 5, 3))
x[:, 0] = 1.0
self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3))))
def test_array_at(self):
a = mx.array(1)
a = a.at[None].add(1)