mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
fix slice update indexing (#1053)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
@@ -767,6 +766,14 @@ auto mlx_slice_update(
|
||||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj))) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
if (nb::isinstance<nb::tuple>(obj)) {
|
||||
// Can't route to slice update if any arrays are present
|
||||
for (auto idx : nb::cast<nb::tuple>(obj)) {
|
||||
if (nb::isinstance<array>(idx)) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Should be able to route to slice update
|
||||
|
||||
@@ -804,14 +811,6 @@ auto mlx_slice_update(
|
||||
// It must be a tuple
|
||||
auto entries = nb::cast<nb::tuple>(obj);
|
||||
|
||||
// Can't route to slice update if any arrays are present
|
||||
for (int i = 0; i < entries.size(); i++) {
|
||||
auto idx = entries[i];
|
||||
if (nb::isinstance<array>(idx)) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
}
|
||||
|
||||
// Expand ellipses into a series of ':' slices
|
||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||
|
||||
@@ -828,9 +827,19 @@ auto mlx_slice_update(
|
||||
}
|
||||
|
||||
// Process entries
|
||||
std::vector<int> upd_expand_dims;
|
||||
int ax = 0;
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
std::vector<int> up_reshape(src.ndim());
|
||||
int ax = src.ndim() - 1;
|
||||
int up_ax = up.ndim() - 1;
|
||||
for (; ax >= non_none_indices; ax--) {
|
||||
if (up_ax >= 0) {
|
||||
up_reshape[ax] = up.shape(up_ax);
|
||||
up_ax--;
|
||||
} else {
|
||||
up_reshape[ax] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = indices.size() - 1; i >= 0; --i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||
get_slice_params(
|
||||
@@ -839,18 +848,19 @@ auto mlx_slice_update(
|
||||
strides[ax],
|
||||
nb::cast<nb::slice>(pyidx),
|
||||
src.shape(ax));
|
||||
ax++;
|
||||
up_reshape[ax] = (up_ax >= 0) ? up.shape(up_ax--) : 1;
|
||||
ax--;
|
||||
} else if (nb::isinstance<nb::int_>(pyidx)) {
|
||||
int st = nb::cast<int>(pyidx);
|
||||
st = (st < 0) ? st + src.shape(ax) : st;
|
||||
starts[ax] = st;
|
||||
stops[ax] = st + 1;
|
||||
upd_expand_dims.push_back(ax);
|
||||
ax++;
|
||||
up_reshape[ax] = 1;
|
||||
ax--;
|
||||
}
|
||||
}
|
||||
|
||||
up = expand_dims(up, upd_expand_dims);
|
||||
up = reshape(up, std::move(up_reshape));
|
||||
auto out = slice_update(src, up, starts, stops, strides);
|
||||
return std::make_pair(true, out);
|
||||
}
|
||||
|
@@ -1262,6 +1262,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4)
|
||||
)
|
||||
|
||||
x = mx.zeros((2, 3, 4, 5, 3))
|
||||
x[..., 0] = 1.0
|
||||
self.assertTrue(mx.array_equal(x[..., 0], mx.ones((2, 3, 4, 5))))
|
||||
|
||||
x = mx.zeros((2, 3, 4, 5, 3))
|
||||
x[:, 0] = 1.0
|
||||
self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3))))
|
||||
|
||||
def test_array_at(self):
|
||||
a = mx.array(1)
|
||||
a = a.at[None].add(1)
|
||||
|
Reference in New Issue
Block a user