mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
fix vmap for flatten (#1955)
This commit is contained in:
@@ -659,6 +659,16 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
def test_vmap_flatten(self):
|
||||
def fun(x):
|
||||
return mx.flatten(x, 0, 1)
|
||||
|
||||
x = mx.zeros((2, 3, 4))
|
||||
|
||||
self.assertEqual(mx.vmap(fun)(x).shape, (2, 12))
|
||||
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
|
||||
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user