From 631dfbe67309fb630795cd612739cbe54c75e222 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 14 May 2024 15:04:58 -0700 Subject: [PATCH] fix scatter index bug (#1122) --- python/src/indexing.cpp | 4 ++-- python/tests/test_array.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 384aa7c84..f4f9fc053 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -540,12 +540,12 @@ std::tuple, array, std::vector> mlx_scatter_args_slice( // Get slice size int slice_size = (end - start); - // Broadcast update to slide size + // Broadcast update to slice size std::vector up_shape_broadcast = {1, slice_size}; up_shape_broadcast.insert( up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end()); - up = broadcast_to(update, up_shape_broadcast); + up = broadcast_to(up, up_shape_broadcast); auto indices = std::vector{idx}; auto axes = std::vector{0}; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index fc97a833f..f79fad73e 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1318,6 +1318,11 @@ class TestArray(mlx_tests.MLXTestCase): a = a.at[idx_x, :, 0].minimum(update) self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update)) + update = mx.array([1.0, 2.0])[None, None, None] + src = mx.array([1.0, 2.0])[None, :] + src = src.at[0:1].add(update) + self.assertTrue(mx.array_equal(src, mx.array([[2.0, 4.0]]))) + def test_slice_negative_step(self): a_np = np.arange(20) a_mx = mx.array(a_np)