fix scatter index bug (#1122)

This commit is contained in:
Awni Hannun
2024-05-14 15:04:58 -07:00
committed by GitHub
parent 56a4eaed72
commit 631dfbe673
2 changed files with 7 additions and 2 deletions

View File

@@ -1318,6 +1318,11 @@ class TestArray(mlx_tests.MLXTestCase):
a = a.at[idx_x, :, 0].minimum(update)
self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))
update = mx.array([1.0, 2.0])[None, None, None]
src = mx.array([1.0, 2.0])[None, :]
src = src.at[0:1].add(update)
self.assertTrue(mx.array_equal(src, mx.array([[2.0, 4.0]])))
def test_slice_negative_step(self):
a_np = np.arange(20)
a_mx = mx.array(a_np)