fix vmap for flatten (#1955)

This commit is contained in:
Awni Hannun 2025-03-11 10:42:22 -07:00 committed by GitHub
parent 736a340478
commit 32da94507a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 1 deletions

View File

@ -1766,13 +1766,19 @@ std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
auto ax = axes[0];
auto start_axis = start_axis_;
auto end_axis = end_axis_;
auto in = inputs[0];
if (ax < start_axis) {
start_axis++;
end_axis++;
} else if (ax <= end_axis_) {
start_axis++;
end_axis++;
in = moveaxis(in, ax, 0, stream());
ax = 0;
} else {
ax -= (end_axis - start_axis);
}
return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
return {{flatten(in, start_axis, end_axis, stream())}, {ax}};
}
bool Flatten::is_equivalent(const Primitive& other) const {

View File

@ -659,6 +659,16 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(mem_pre, mem_post)
def test_vmap_flatten(self):
def fun(x):
return mx.flatten(x, 0, 1)
x = mx.zeros((2, 3, 4))
self.assertEqual(mx.vmap(fun)(x).shape, (2, 12))
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
if __name__ == "__main__":
unittest.main()