[Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035)

* Not sure if this is correct

* Format

* Edit tests

* Add negative test

* Format

* add one more test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jacket
2024-04-29 09:57:28 -05:00
committed by GitHub
parent c4a471c99d
commit 490c0c4fdc
2 changed files with 12 additions and 3 deletions

View File

@@ -845,9 +845,7 @@ auto mlx_slice_update(
st = (st < 0) ? st + src.shape(ax) : st;
starts[ax] = st;
stops[ax] = st + 1;
if (src.ndim() - ax < up.ndim()) {
upd_expand_dims.push_back(ax - src.ndim());
}
upd_expand_dims.push_back(ax);
ax++;
}
}