fix slice update indexing (#1053)

This commit is contained in:
Awni Hannun 2024-04-29 12:17:40 -07:00 committed by GitHub
parent 490c0c4fdc
commit 09f1777896
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 38 additions and 23 deletions

View File

@ -265,8 +265,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
// Throw error if unable to compile library // Throw error if unable to compile library
if (!mtl_lib) { if (!mtl_lib) {
std::ostringstream msg; std::ostringstream msg;
msg << "[metal::Device] Unable to load build metal library from source" msg << "[metal::Device] Unable to build metal library from source" << "\n";
<< "\n";
if (error) { if (error) {
msg << error->localizedDescription()->utf8String() << "\n"; msg << error->localizedDescription()->utf8String() << "\n";
} }
@ -285,8 +284,7 @@ MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
// Throw error if unable to compile library // Throw error if unable to compile library
if (!mtl_lib) { if (!mtl_lib) {
std::ostringstream msg; std::ostringstream msg;
msg << "[metal::Device] Unable to load build stitched metal library" msg << "[metal::Device] Unable to build stitched metal library" << "\n";
<< "\n";
if (error) { if (error) {
msg << error->localizedDescription()->utf8String() << "\n"; msg << error->localizedDescription()->utf8String() << "\n";
} }

View File

@ -426,7 +426,7 @@ array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
int ax = axis < 0 ? axis + out_dim : axis; int ax = axis < 0 ? axis + out_dim : axis;
if (ax < 0 || ax >= out_dim) { if (ax < 0 || ax >= out_dim) {
std::ostringstream msg; std::ostringstream msg;
msg << "[expand_dims] Invalid axes " << axis << " for output array with " msg << "[expand_dims] Invalid axis " << axis << " for output array with "
<< a.ndim() << " dimensions."; << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -452,7 +452,7 @@ array expand_dims(
ax = ax < 0 ? ax + out_ndim : ax; ax = ax < 0 ? ax + out_ndim : ax;
if (ax < 0 || ax >= out_ndim) { if (ax < 0 || ax >= out_ndim) {
std::ostringstream msg; std::ostringstream msg;
msg << "[expand_dims] Invalid axes " << ax << " for output array with " msg << "[expand_dims] Invalid axis " << ax << " for output array with "
<< a.ndim() << " dimensions."; << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -591,7 +591,6 @@ array slice_update(
if (!has_neg_strides && upd_shape == src.shape()) { if (!has_neg_strides && upd_shape == src.shape()) {
return astype(update_broadcasted, src.dtype(), s); return astype(update_broadcasted, src.dtype(), s);
} }
return array( return array(
src.shape(), src.shape(),
src.dtype(), src.dtype(),

View File

@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
@ -767,6 +766,14 @@ auto mlx_slice_update(
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj))) { (!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj))) {
return std::make_pair(false, src); return std::make_pair(false, src);
} }
if (nb::isinstance<nb::tuple>(obj)) {
// Can't route to slice update if any arrays are present
for (auto idx : nb::cast<nb::tuple>(obj)) {
if (nb::isinstance<array>(idx)) {
return std::make_pair(false, src);
}
}
}
// Should be able to route to slice update // Should be able to route to slice update
@ -804,14 +811,6 @@ auto mlx_slice_update(
// It must be a tuple // It must be a tuple
auto entries = nb::cast<nb::tuple>(obj); auto entries = nb::cast<nb::tuple>(obj);
// Can't route to slice update if any arrays are present
for (int i = 0; i < entries.size(); i++) {
auto idx = entries[i];
if (nb::isinstance<array>(idx)) {
return std::make_pair(false, src);
}
}
// Expand ellipses into a series of ':' slices // Expand ellipses into a series of ':' slices
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries); auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
@ -828,9 +827,19 @@ auto mlx_slice_update(
} }
// Process entries // Process entries
std::vector<int> upd_expand_dims; std::vector<int> up_reshape(src.ndim());
int ax = 0; int ax = src.ndim() - 1;
for (int i = 0; i < indices.size(); ++i) { int up_ax = up.ndim() - 1;
for (; ax >= non_none_indices; ax--) {
if (up_ax >= 0) {
up_reshape[ax] = up.shape(up_ax);
up_ax--;
} else {
up_reshape[ax] = 1;
}
}
for (int i = indices.size() - 1; i >= 0; --i) {
auto& pyidx = indices[i]; auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) { if (nb::isinstance<nb::slice>(pyidx)) {
get_slice_params( get_slice_params(
@ -839,18 +848,19 @@ auto mlx_slice_update(
strides[ax], strides[ax],
nb::cast<nb::slice>(pyidx), nb::cast<nb::slice>(pyidx),
src.shape(ax)); src.shape(ax));
ax++; up_reshape[ax] = (up_ax >= 0) ? up.shape(up_ax--) : 1;
ax--;
} else if (nb::isinstance<nb::int_>(pyidx)) { } else if (nb::isinstance<nb::int_>(pyidx)) {
int st = nb::cast<int>(pyidx); int st = nb::cast<int>(pyidx);
st = (st < 0) ? st + src.shape(ax) : st; st = (st < 0) ? st + src.shape(ax) : st;
starts[ax] = st; starts[ax] = st;
stops[ax] = st + 1; stops[ax] = st + 1;
upd_expand_dims.push_back(ax); up_reshape[ax] = 1;
ax++; ax--;
} }
} }
up = expand_dims(up, upd_expand_dims); up = reshape(up, std::move(up_reshape));
auto out = slice_update(src, up, starts, stops, strides); auto out = slice_update(src, up, starts, stops, strides);
return std::make_pair(true, out); return std::make_pair(true, out);
} }

View File

@ -1262,6 +1262,14 @@ class TestArray(mlx_tests.MLXTestCase):
np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4) np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4)
) )
x = mx.zeros((2, 3, 4, 5, 3))
x[..., 0] = 1.0
self.assertTrue(mx.array_equal(x[..., 0], mx.ones((2, 3, 4, 5))))
x = mx.zeros((2, 3, 4, 5, 3))
x[:, 0] = 1.0
self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3))))
def test_array_at(self): def test_array_at(self):
a = mx.array(1) a = mx.array(1)
a = a.at[None].add(1) a = a.at[None].add(1)