Set item bug fix (#879)

* set item shaping bug fix

* Add extra tests
This commit is contained in:
Jagrit Digani
2024-03-22 12:11:17 -07:00
committed by GitHub
parent fcda3a0e66
commit 8e5a5a1ccd
2 changed files with 51 additions and 12 deletions

View File

@@ -1058,6 +1058,29 @@ class TestArray(mlx_tests.MLXTestCase):
a[2:-2, 2:-2] = 4
self.assertEqual(a[2, 2].item(), 4)
# Check slice array slice
check_slices(
np.zeros((5, 4, 4)),
np.arange(4 * 2 * 3).reshape(4, 2, 3),
slice(0, 4),
np.array([1, 3]),
slice(None, -1),
)
check_slices(
np.zeros((5, 4, 4)),
np.arange(4 * 2 * 2).reshape(4, 2, 2),
slice(0, 4),
np.array([1, 3]),
slice(0, 4, 2),
)
check_slices(
np.zeros((1, 10, 4)),
np.arange(2 * 4).reshape(1, 2, 4),
slice(None, None, None),
np.array([1, 3]),
)
def test_array_at(self):
a = mx.array(1)
a = a.at[None].add(1)