fix scatter index bug (#1122)

This commit is contained in:
Awni Hannun
2024-05-14 15:04:58 -07:00
committed by GitHub
parent 56a4eaed72
commit 631dfbe673
2 changed files with 7 additions and 2 deletions

View File

@@ -540,12 +540,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
// Get slice size
int slice_size = (end - start);
// Broadcast update to slide size
// Broadcast update to slice size
std::vector<int> up_shape_broadcast = {1, slice_size};
up_shape_broadcast.insert(
up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(update, up_shape_broadcast);
up = broadcast_to(up, up_shape_broadcast);
auto indices = std::vector<array>{idx};
auto axes = std::vector<int>{0};