Indexing bug fix (#947)

* Fix axes accounting

* Add tests
This commit is contained in:
Jagrit Digani
2024-04-01 12:18:50 -07:00
committed by GitHub
parent 02fedbf1da
commit 639e06e1f3
2 changed files with 23 additions and 2 deletions

View File

@@ -1210,6 +1210,23 @@ class TestArray(mlx_tests.MLXTestCase):
np.array([1, 3]),
)
check_slices(
np.zeros((3, 4, 5, 3)),
np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
np.array([2, 1]),
slice(None, None, None),
slice(None, None, 2),
slice(None, None, None),
)
check_slices(
np.zeros((3, 4, 5, 3)),
np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
np.array([2, 1]),
slice(None, None, None),
slice(None, None, 2),
)
def test_array_at(self):
a = mx.array(1)
a = a.at[None].add(1)