mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
list based indexing (#1150)
This commit is contained in:
@@ -1740,6 +1740,68 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
y = np.from_dlpack(x)
|
||||
self.assertTrue(mx.array_equal(y, x))
|
||||
|
||||
def test_getitem_with_list(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
idx = [0, 2, 4]
|
||||
self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))
|
||||
|
||||
a = mx.array([[1, 2], [3, 4], [5, 6]])
|
||||
idx = [0, 2]
|
||||
self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))
|
||||
|
||||
a = mx.arange(10).reshape(5, 2)
|
||||
idx = [0, 2, 4]
|
||||
self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))
|
||||
|
||||
idx = [0, 2]
|
||||
a = mx.arange(16).reshape(4, 4)
|
||||
anp = np.array(a)
|
||||
self.assertTrue(np.array_equal(a[idx, 0], anp[idx, 0]))
|
||||
self.assertTrue(np.array_equal(a[idx, :], anp[idx, :]))
|
||||
self.assertTrue(np.array_equal(a[0, idx], anp[0, idx]))
|
||||
self.assertTrue(np.array_equal(a[:, idx], anp[:, idx]))
|
||||
|
||||
def test_setitem_with_list(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
anp = np.array(a)
|
||||
idx = [0, 2, 4]
|
||||
a[idx] = 3
|
||||
anp[idx] = 3
|
||||
self.assertTrue(np.array_equal(a, anp))
|
||||
|
||||
a = mx.array([[1, 2], [3, 4], [5, 6]])
|
||||
idx = [0, 2]
|
||||
anp = np.array(a)
|
||||
a[idx] = 3
|
||||
anp[idx] = 3
|
||||
self.assertTrue(np.array_equal(a, anp))
|
||||
|
||||
a = mx.arange(10).reshape(5, 2)
|
||||
idx = [0, 2, 4]
|
||||
anp = np.array(a)
|
||||
a[idx] = 3
|
||||
anp[idx] = 3
|
||||
self.assertTrue(np.array_equal(a, anp))
|
||||
|
||||
idx = [0, 2]
|
||||
a = mx.arange(16).reshape(4, 4)
|
||||
anp = np.array(a)
|
||||
a[idx, 0] = 1
|
||||
anp[idx, 0] = 1
|
||||
self.assertTrue(np.array_equal(a, anp))
|
||||
|
||||
a[idx, :] = 2
|
||||
anp[idx, :] = 2
|
||||
self.assertTrue(np.array_equal(a, anp))
|
||||
|
||||
a[0, idx] = 3
|
||||
anp[0, idx] = 3
|
||||
self.assertTrue(np.array_equal(a, anp))
|
||||
|
||||
a[:, idx] = 4
|
||||
anp[:, idx] = 4
|
||||
self.assertTrue(np.array_equal(a, anp))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user