No extra reshape (#1557)

* no extra reshape

* lint
This commit is contained in:
Awni Hannun 2024-11-02 19:07:20 -07:00 committed by GitHub
parent 46d8b16ab4
commit 09bc32f62f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -803,9 +803,10 @@ auto mlx_slice_update(
// Pre process tuple
auto upd = to_array(v, src.dtype());
// Remove leading singletons dimensions from the update
// Remove extra leading singletons dimensions from the update
int s = 0;
for (; s < upd.ndim() && upd.shape(s) == 1; s++) {
for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim();
s++) {
};
auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
up_shape = up_shape.empty() ? std::vector{1} : up_shape;