mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
fix scatter index bug (#1122)
This commit is contained in:
@@ -1318,6 +1318,11 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a = a.at[idx_x, :, 0].minimum(update)
|
||||
self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))
|
||||
|
||||
update = mx.array([1.0, 2.0])[None, None, None]
|
||||
src = mx.array([1.0, 2.0])[None, :]
|
||||
src = src.at[0:1].add(update)
|
||||
self.assertTrue(mx.array_equal(src, mx.array([[2.0, 4.0]])))
|
||||
|
||||
def test_slice_negative_step(self):
|
||||
a_np = np.arange(20)
|
||||
a_mx = mx.array(a_np)
|
||||
|
Reference in New Issue
Block a user