mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
fix scatter index bug (#1122)
This commit is contained in:
@@ -540,12 +540,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
// Get slice size
|
||||
int slice_size = (end - start);
|
||||
|
||||
// Broadcast update to slide size
|
||||
// Broadcast update to slice size
|
||||
std::vector<int> up_shape_broadcast = {1, slice_size};
|
||||
up_shape_broadcast.insert(
|
||||
up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
|
||||
|
||||
up = broadcast_to(update, up_shape_broadcast);
|
||||
up = broadcast_to(up, up_shape_broadcast);
|
||||
|
||||
auto indices = std::vector<array>{idx};
|
||||
auto axes = std::vector<int>{0};
|
||||
|
Reference in New Issue
Block a user