mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
allow take to work with integer index (#1440)
This commit is contained in:
@@ -1059,6 +1059,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
# Take with integer index
|
||||
a = mx.arange(8).reshape(2, 4)
|
||||
out = mx.take(a, 1, axis=0)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6, 7])))
|
||||
out = mx.take(a, 1, axis=1)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
|
||||
|
||||
def test_take_along_axis(self):
|
||||
a_np = np.arange(8).reshape(2, 2, 2)
|
||||
a_mlx = mx.array(a_np)
|
||||
|
Reference in New Issue
Block a user