allow take to work with integer index (#1440)

This commit is contained in:
Awni Hannun
2024-09-26 15:58:03 -07:00
committed by GitHub
parent 5b6f38df2b
commit 718aea3f1d
4 changed files with 74 additions and 17 deletions

View File

@@ -1059,6 +1059,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
# Take with integer index
a = mx.arange(8).reshape(2, 4)
out = mx.take(a, 1, axis=0)
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6, 7])))
out = mx.take(a, 1, axis=1)
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
def test_take_along_axis(self):
a_np = np.arange(8).reshape(2, 2, 2)
a_mlx = mx.array(a_np)