mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
Fix concatenate vmap (#1600)
This commit is contained in:

committed by
GitHub

parent
2af7e8a9a6
commit
5e89aace9b
@@ -527,6 +527,26 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
|
||||
|
||||
def test_vmap_concatenate(self):
|
||||
x = mx.random.uniform(shape=(2, 2, 2))
|
||||
|
||||
def cat_fun(x, y):
|
||||
return mx.concatenate([x, y], axis=1)
|
||||
|
||||
def cat_constant(x):
|
||||
y = mx.ones((2, 1))
|
||||
return mx.concatenate([x, y], 1)
|
||||
|
||||
out = mx.vmap(cat_fun, in_axes=(0, 2))(x, x)
|
||||
target = mx.stack(
|
||||
[mx.concatenate([x[i], x[:, :, i]], axis=1) for i in range(2)]
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, target))
|
||||
|
||||
out = mx.vmap(cat_constant)(x)
|
||||
target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2)
|
||||
self.assertTrue(mx.array_equal(out, target))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user