mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
parent
46d8b16ab4
commit
09bc32f62f
@ -803,9 +803,10 @@ auto mlx_slice_update(
|
|||||||
// Pre process tuple
|
// Pre process tuple
|
||||||
auto upd = to_array(v, src.dtype());
|
auto upd = to_array(v, src.dtype());
|
||||||
|
|
||||||
// Remove leading singletons dimensions from the update
|
// Remove extra leading singletons dimensions from the update
|
||||||
int s = 0;
|
int s = 0;
|
||||||
for (; s < upd.ndim() && upd.shape(s) == 1; s++) {
|
for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim();
|
||||||
|
s++) {
|
||||||
};
|
};
|
||||||
auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
|
auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
|
||||||
up_shape = up_shape.empty() ? std::vector{1} : up_shape;
|
up_shape = up_shape.empty() ? std::vector{1} : up_shape;
|
||||||
|
Loading…
Reference in New Issue
Block a user