mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
fix vmap for flatten (#1955)
This commit is contained in:
@@ -1766,13 +1766,19 @@ std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
|
||||
auto ax = axes[0];
|
||||
auto start_axis = start_axis_;
|
||||
auto end_axis = end_axis_;
|
||||
auto in = inputs[0];
|
||||
if (ax < start_axis) {
|
||||
start_axis++;
|
||||
end_axis++;
|
||||
} else if (ax <= end_axis_) {
|
||||
start_axis++;
|
||||
end_axis++;
|
||||
in = moveaxis(in, ax, 0, stream());
|
||||
ax = 0;
|
||||
} else {
|
||||
ax -= (end_axis - start_axis);
|
||||
}
|
||||
return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
|
||||
return {{flatten(in, start_axis, end_axis, stream())}, {ax}};
|
||||
}
|
||||
|
||||
bool Flatten::is_equivalent(const Primitive& other) const {
|
||||
|
Reference in New Issue
Block a user