mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 05:01:19 +08:00
fix vmap for flatten (#1955)
This commit is contained in:
parent
736a340478
commit
32da94507a
@ -1766,13 +1766,19 @@ std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
|
|||||||
auto ax = axes[0];
|
auto ax = axes[0];
|
||||||
auto start_axis = start_axis_;
|
auto start_axis = start_axis_;
|
||||||
auto end_axis = end_axis_;
|
auto end_axis = end_axis_;
|
||||||
|
auto in = inputs[0];
|
||||||
if (ax < start_axis) {
|
if (ax < start_axis) {
|
||||||
start_axis++;
|
start_axis++;
|
||||||
end_axis++;
|
end_axis++;
|
||||||
|
} else if (ax <= end_axis_) {
|
||||||
|
start_axis++;
|
||||||
|
end_axis++;
|
||||||
|
in = moveaxis(in, ax, 0, stream());
|
||||||
|
ax = 0;
|
||||||
} else {
|
} else {
|
||||||
ax -= (end_axis - start_axis);
|
ax -= (end_axis - start_axis);
|
||||||
}
|
}
|
||||||
return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
|
return {{flatten(in, start_axis, end_axis, stream())}, {ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Flatten::is_equivalent(const Primitive& other) const {
|
bool Flatten::is_equivalent(const Primitive& other) const {
|
||||||
|
@ -659,6 +659,16 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(mem_pre, mem_post)
|
self.assertEqual(mem_pre, mem_post)
|
||||||
|
|
||||||
|
def test_vmap_flatten(self):
|
||||||
|
def fun(x):
|
||||||
|
return mx.flatten(x, 0, 1)
|
||||||
|
|
||||||
|
x = mx.zeros((2, 3, 4))
|
||||||
|
|
||||||
|
self.assertEqual(mx.vmap(fun)(x).shape, (2, 12))
|
||||||
|
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
|
||||||
|
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user