mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 19:11:17 +08:00
fix vmap for flatten (#1955)
This commit is contained in:
parent
736a340478
commit
32da94507a
@ -1766,13 +1766,19 @@ std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
|
||||
auto ax = axes[0];
|
||||
auto start_axis = start_axis_;
|
||||
auto end_axis = end_axis_;
|
||||
auto in = inputs[0];
|
||||
if (ax < start_axis) {
|
||||
start_axis++;
|
||||
end_axis++;
|
||||
} else if (ax <= end_axis_) {
|
||||
start_axis++;
|
||||
end_axis++;
|
||||
in = moveaxis(in, ax, 0, stream());
|
||||
ax = 0;
|
||||
} else {
|
||||
ax -= (end_axis - start_axis);
|
||||
}
|
||||
return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
|
||||
return {{flatten(in, start_axis, end_axis, stream())}, {ax}};
|
||||
}
|
||||
|
||||
bool Flatten::is_equivalent(const Primitive& other) const {
|
||||
|
@ -659,6 +659,16 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
def test_vmap_flatten(self):
|
||||
def fun(x):
|
||||
return mx.flatten(x, 0, 1)
|
||||
|
||||
x = mx.zeros((2, 3, 4))
|
||||
|
||||
self.assertEqual(mx.vmap(fun)(x).shape, (2, 12))
|
||||
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
|
||||
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user