mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Fix Split::vmap (#1845)
This commit is contained in:

committed by
GitHub

parent
1c0c118f7c
commit
9eb7d7362f
@@ -596,6 +596,18 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
def test_vmap_split_vmap(self):
|
||||
def fun(x):
|
||||
a, b = mx.split(x, 2, 1)
|
||||
return mx.concatenate([b, a], 1)
|
||||
|
||||
x = mx.ones((5, 6, 7))
|
||||
y = mx.ones((5, 4, 6, 7))
|
||||
fx = fun(x)
|
||||
fy = mx.vmap(fun, in_axes=1)(y)
|
||||
self.assertEqual(fx.shape, (5, 6, 7))
|
||||
self.assertEqual(fy.shape, (4, 5, 6, 7))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user