Indexing bug (#233)

* fix

* test
This commit is contained in:
Awni Hannun
2023-12-20 10:44:01 -08:00
committed by GitHub
parent 2807c6aff0
commit f40d17047d
2 changed files with 10 additions and 0 deletions

View File

@@ -727,6 +727,11 @@ class TestArray(mlx_tests.MLXTestCase):
np.array_equal(a_np[idx_np, idx_np], np.array(a_mlx[idx_mlx, idx_mlx]))
)
# Slicing with negative indices and integer
a_np = np.arange(10).reshape(5, 2)
a_mlx = mx.array(a_np)
self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))
def test_setitem(self):
a = mx.array(0)
a[None] = 1