fix scatter index bug (#1122)

This commit is contained in:
Awni Hannun 2024-05-14 15:04:58 -07:00 committed by GitHub
parent 56a4eaed72
commit 631dfbe673
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 2 deletions

View File

@ -540,12 +540,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
// Get slice size
int slice_size = (end - start);
// Broadcast update to slide size
// Broadcast update to slice size
std::vector<int> up_shape_broadcast = {1, slice_size};
up_shape_broadcast.insert(
up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(update, up_shape_broadcast);
up = broadcast_to(up, up_shape_broadcast);
auto indices = std::vector<array>{idx};
auto axes = std::vector<int>{0};

View File

@ -1318,6 +1318,11 @@ class TestArray(mlx_tests.MLXTestCase):
a = a.at[idx_x, :, 0].minimum(update)
self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))
update = mx.array([1.0, 2.0])[None, None, None]
src = mx.array([1.0, 2.0])[None, :]
src = src.at[0:1].add(update)
self.assertTrue(mx.array_equal(src, mx.array([[2.0, 4.0]])))
def test_slice_negative_step(self):
a_np = np.arange(20)
a_mx = mx.array(a_np)