mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
@@ -163,6 +163,61 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
|
||||
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
|
||||
|
||||
def test_vmap_indexing(self):
|
||||
x = mx.arange(16).reshape(2, 2, 2, 2)
|
||||
inds = mx.array([[0, 1, 0], [1, 1, 0]])
|
||||
|
||||
out = mx.vmap(lambda x, y: x[y], in_axes=(0, 0))(x, inds)
|
||||
expected = mx.array(
|
||||
[
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
|
||||
[[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
|
||||
]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
out = mx.vmap(lambda x, y: x[y], in_axes=(0, None))(x, inds)
|
||||
expected = mx.array(
|
||||
[
|
||||
[
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
|
||||
[[[4, 5], [6, 7]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],
|
||||
],
|
||||
[
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
|
||||
[[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],
|
||||
],
|
||||
]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
out = mx.vmap(lambda x, y: x[y], in_axes=(None, 0))(x, inds)
|
||||
expected = mx.array(
|
||||
[
|
||||
[
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
|
||||
],
|
||||
[
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]]],
|
||||
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
|
||||
],
|
||||
]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
inds2 = mx.array([[0, 1, 0], [0, 1, 0]])
|
||||
out = mx.vmap(lambda x, y, z: x[y, z], in_axes=(None, 0, 0))(x, inds, inds2)
|
||||
expected = mx.array(
|
||||
[
|
||||
[[[0, 1], [2, 3]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],
|
||||
[[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],
|
||||
]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user