mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fix slice update indexing (#1053)
This commit is contained in:
parent
490c0c4fdc
commit
09f1777896
@ -265,8 +265,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load build metal library from source"
|
||||
<< "\n";
|
||||
msg << "[metal::Device] Unable to build metal library from source" << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
@ -285,8 +284,7 @@ MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load build stitched metal library"
|
||||
<< "\n";
|
||||
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
|
@ -426,7 +426,7 @@ array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
||||
int ax = axis < 0 ? axis + out_dim : axis;
|
||||
if (ax < 0 || ax >= out_dim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[expand_dims] Invalid axes " << axis << " for output array with "
|
||||
msg << "[expand_dims] Invalid axis " << axis << " for output array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@ -452,7 +452,7 @@ array expand_dims(
|
||||
ax = ax < 0 ? ax + out_ndim : ax;
|
||||
if (ax < 0 || ax >= out_ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "[expand_dims] Invalid axes " << ax << " for output array with "
|
||||
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@ -591,7 +591,6 @@ array slice_update(
|
||||
if (!has_neg_strides && upd_shape == src.shape()) {
|
||||
return astype(update_broadcasted, src.dtype(), s);
|
||||
}
|
||||
|
||||
return array(
|
||||
src.shape(),
|
||||
src.dtype(),
|
||||
|
@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
@ -767,6 +766,14 @@ auto mlx_slice_update(
|
||||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj))) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
if (nb::isinstance<nb::tuple>(obj)) {
|
||||
// Can't route to slice update if any arrays are present
|
||||
for (auto idx : nb::cast<nb::tuple>(obj)) {
|
||||
if (nb::isinstance<array>(idx)) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Should be able to route to slice update
|
||||
|
||||
@ -804,14 +811,6 @@ auto mlx_slice_update(
|
||||
// It must be a tuple
|
||||
auto entries = nb::cast<nb::tuple>(obj);
|
||||
|
||||
// Can't route to slice update if any arrays are present
|
||||
for (int i = 0; i < entries.size(); i++) {
|
||||
auto idx = entries[i];
|
||||
if (nb::isinstance<array>(idx)) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
}
|
||||
|
||||
// Expand ellipses into a series of ':' slices
|
||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||
|
||||
@ -828,9 +827,19 @@ auto mlx_slice_update(
|
||||
}
|
||||
|
||||
// Process entries
|
||||
std::vector<int> upd_expand_dims;
|
||||
int ax = 0;
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
std::vector<int> up_reshape(src.ndim());
|
||||
int ax = src.ndim() - 1;
|
||||
int up_ax = up.ndim() - 1;
|
||||
for (; ax >= non_none_indices; ax--) {
|
||||
if (up_ax >= 0) {
|
||||
up_reshape[ax] = up.shape(up_ax);
|
||||
up_ax--;
|
||||
} else {
|
||||
up_reshape[ax] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = indices.size() - 1; i >= 0; --i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||
get_slice_params(
|
||||
@ -839,18 +848,19 @@ auto mlx_slice_update(
|
||||
strides[ax],
|
||||
nb::cast<nb::slice>(pyidx),
|
||||
src.shape(ax));
|
||||
ax++;
|
||||
up_reshape[ax] = (up_ax >= 0) ? up.shape(up_ax--) : 1;
|
||||
ax--;
|
||||
} else if (nb::isinstance<nb::int_>(pyidx)) {
|
||||
int st = nb::cast<int>(pyidx);
|
||||
st = (st < 0) ? st + src.shape(ax) : st;
|
||||
starts[ax] = st;
|
||||
stops[ax] = st + 1;
|
||||
upd_expand_dims.push_back(ax);
|
||||
ax++;
|
||||
up_reshape[ax] = 1;
|
||||
ax--;
|
||||
}
|
||||
}
|
||||
|
||||
up = expand_dims(up, upd_expand_dims);
|
||||
up = reshape(up, std::move(up_reshape));
|
||||
auto out = slice_update(src, up, starts, stops, strides);
|
||||
return std::make_pair(true, out);
|
||||
}
|
||||
|
@ -1262,6 +1262,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4)
|
||||
)
|
||||
|
||||
x = mx.zeros((2, 3, 4, 5, 3))
|
||||
x[..., 0] = 1.0
|
||||
self.assertTrue(mx.array_equal(x[..., 0], mx.ones((2, 3, 4, 5))))
|
||||
|
||||
x = mx.zeros((2, 3, 4, 5, 3))
|
||||
x[:, 0] = 1.0
|
||||
self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3))))
|
||||
|
||||
def test_array_at(self):
|
||||
a = mx.array(1)
|
||||
a = a.at[None].add(1)
|
||||
|
Loading…
Reference in New Issue
Block a user