Fix concatenate vmap (#1600)

This commit is contained in:
Angelos Katharopoulos
2024-11-19 10:44:04 -08:00
committed by GitHub
parent 2af7e8a9a6
commit 5e89aace9b
2 changed files with 43 additions and 11 deletions

View File

@@ -527,6 +527,26 @@ class TestVmap(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
def test_vmap_concatenate(self):
x = mx.random.uniform(shape=(2, 2, 2))
def cat_fun(x, y):
return mx.concatenate([x, y], axis=1)
def cat_constant(x):
y = mx.ones((2, 1))
return mx.concatenate([x, y], 1)
out = mx.vmap(cat_fun, in_axes=(0, 2))(x, x)
target = mx.stack(
[mx.concatenate([x[i], x[:, :, i]], axis=1) for i in range(2)]
)
self.assertTrue(mx.array_equal(out, target))
out = mx.vmap(cat_constant)(x)
target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2)
self.assertTrue(mx.array_equal(out, target))
if __name__ == "__main__":
unittest.main()