fix vmap for flatten (#1955)

This commit is contained in:
Awni Hannun
2025-03-11 10:42:22 -07:00
committed by GitHub
parent 736a340478
commit 32da94507a
2 changed files with 17 additions and 1 deletions

View File

@@ -1766,13 +1766,19 @@ std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
auto ax = axes[0];
auto start_axis = start_axis_;
auto end_axis = end_axis_;
auto in = inputs[0];
if (ax < start_axis) {
start_axis++;
end_axis++;
} else if (ax <= end_axis_) {
start_axis++;
end_axis++;
in = moveaxis(in, ax, 0, stream());
ax = 0;
} else {
ax -= (end_axis - start_axis);
}
return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
return {{flatten(in, start_axis, end_axis, stream())}, {ax}};
}
bool Flatten::is_equivalent(const Primitive& other) const {