fix vmap for flatten (#1955)

This commit is contained in:
Awni Hannun
2025-03-11 10:42:22 -07:00
committed by GitHub
parent 736a340478
commit 32da94507a
2 changed files with 17 additions and 1 deletions

View File

@@ -659,6 +659,16 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(mem_pre, mem_post)
def test_vmap_flatten(self):
def fun(x):
return mx.flatten(x, 0, 1)
x = mx.zeros((2, 3, 4))
self.assertEqual(mx.vmap(fun)(x).shape, (2, 12))
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
if __name__ == "__main__":
unittest.main()