diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 60c13b2c9..b5e5ec82e 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1766,13 +1766,19 @@ std::pair, std::vector> Flatten::vmap( auto ax = axes[0]; auto start_axis = start_axis_; auto end_axis = end_axis_; + auto in = inputs[0]; if (ax < start_axis) { start_axis++; end_axis++; + } else if (ax <= end_axis_) { + start_axis++; + end_axis++; + in = moveaxis(in, ax, 0, stream()); + ax = 0; } else { ax -= (end_axis - start_axis); } - return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}}; + return {{flatten(in, start_axis, end_axis, stream())}, {ax}}; } bool Flatten::is_equivalent(const Primitive& other) const { diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 2eee33b5c..d1d4f0bd4 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -659,6 +659,16 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_vmap_flatten(self): + def fun(x): + return mx.flatten(x, 0, 1) + + x = mx.zeros((2, 3, 4)) + + self.assertEqual(mx.vmap(fun)(x).shape, (2, 12)) + self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8)) + self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6)) + if __name__ == "__main__": unittest.main()