list based indexing (#1150)

This commit is contained in:
Awni Hannun
2024-05-22 15:52:05 -07:00
committed by GitHub
parent 79ef49b2c2
commit eb8321d863
5 changed files with 425 additions and 328 deletions

View File

@@ -1740,6 +1740,68 @@ class TestArray(mlx_tests.MLXTestCase):
y = np.from_dlpack(x)
self.assertTrue(mx.array_equal(y, x))
def test_getitem_with_list(self):
a = mx.array([1, 2, 3, 4, 5])
idx = [0, 2, 4]
self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))
a = mx.array([[1, 2], [3, 4], [5, 6]])
idx = [0, 2]
self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))
a = mx.arange(10).reshape(5, 2)
idx = [0, 2, 4]
self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))
idx = [0, 2]
a = mx.arange(16).reshape(4, 4)
anp = np.array(a)
self.assertTrue(np.array_equal(a[idx, 0], anp[idx, 0]))
self.assertTrue(np.array_equal(a[idx, :], anp[idx, :]))
self.assertTrue(np.array_equal(a[0, idx], anp[0, idx]))
self.assertTrue(np.array_equal(a[:, idx], anp[:, idx]))
def test_setitem_with_list(self):
a = mx.array([1, 2, 3, 4, 5])
anp = np.array(a)
idx = [0, 2, 4]
a[idx] = 3
anp[idx] = 3
self.assertTrue(np.array_equal(a, anp))
a = mx.array([[1, 2], [3, 4], [5, 6]])
idx = [0, 2]
anp = np.array(a)
a[idx] = 3
anp[idx] = 3
self.assertTrue(np.array_equal(a, anp))
a = mx.arange(10).reshape(5, 2)
idx = [0, 2, 4]
anp = np.array(a)
a[idx] = 3
anp[idx] = 3
self.assertTrue(np.array_equal(a, anp))
idx = [0, 2]
a = mx.arange(16).reshape(4, 4)
anp = np.array(a)
a[idx, 0] = 1
anp[idx, 0] = 1
self.assertTrue(np.array_equal(a, anp))
a[idx, :] = 2
anp[idx, :] = 2
self.assertTrue(np.array_equal(a, anp))
a[0, idx] = 3
anp[0, idx] = 3
self.assertTrue(np.array_equal(a, anp))
a[:, idx] = 4
anp[:, idx] = 4
self.assertTrue(np.array_equal(a, anp))
if __name__ == "__main__":
unittest.main()