Update set item (#861)

* Update mlx_set_item to handle regular slices without expanding

* Refactor ellipsis handling

* Route mlx_set_item to slice_update where possible

* Update mlx_scatter_args_slice

* Don't route to gather if no array indices
This commit is contained in:
Jagrit Digani
2024-03-21 02:48:13 -07:00
committed by GitHub
parent e849b3424a
commit a5681ebc52
2 changed files with 308 additions and 123 deletions

View File

@@ -558,8 +558,7 @@ array slice_update(
normalize_slice(src.shape(), start, stop, strides);
// Broadcast update shape to slice shape
auto upd_shape_broadcast = broadcast_shapes(upd_shape, update.shape());
auto update_broadcasted = broadcast_to(update, upd_shape_broadcast, s);
auto update_broadcasted = broadcast_to(update, upd_shape, s);
// If the entire src is the slice, just return the update
if (!has_neg_strides && upd_shape == src.shape()) {
@@ -571,7 +570,7 @@ array slice_update(
src.dtype(),
std::make_unique<SliceUpdate>(
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
{src, update});
{src, update_broadcasted});
}
/** Update a slice from the source array with stride 1 in each dimension */