mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding * Refactor ellipsis handling * Route mlx_set_item to slice_update where possible * Update mlx_scatter_args_slice * Don't route to gather if no array indices
This commit is contained in:
@@ -558,8 +558,7 @@ array slice_update(
|
||||
normalize_slice(src.shape(), start, stop, strides);
|
||||
|
||||
// Broadcast update shape to slice shape
|
||||
auto upd_shape_broadcast = broadcast_shapes(upd_shape, update.shape());
|
||||
auto update_broadcasted = broadcast_to(update, upd_shape_broadcast, s);
|
||||
auto update_broadcasted = broadcast_to(update, upd_shape, s);
|
||||
|
||||
// If the entire src is the slice, just return the update
|
||||
if (!has_neg_strides && upd_shape == src.shape()) {
|
||||
@@ -571,7 +570,7 @@ array slice_update(
|
||||
src.dtype(),
|
||||
std::make_unique<SliceUpdate>(
|
||||
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
|
||||
{src, update});
|
||||
{src, update_broadcasted});
|
||||
}
|
||||
|
||||
/** Update a slice from the source array with stride 1 in each dimension */
|
||||
|
Reference in New Issue
Block a user