fix compile merging (#2150)

This commit is contained in:
Awni Hannun 2025-05-02 15:08:50 -07:00 committed by GitHub
parent 481349495b
commit 9c5e7da507
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 0 deletions

View File

@ -168,6 +168,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) {
parent.first.inputs()[parent.second] = dst; parent.first.inputs()[parent.second] = dst;
pairs.push_back(parent); pairs.push_back(parent);
} }
// If src is a parent of dst, remove it from dst's parents
for (auto it = pairs.begin(); it != pairs.end();) {
if (it->first.id() == src.id()) {
it = pairs.erase(it);
} else {
it++;
}
}
// Remove the source from the map to avoid fusing with it again // Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents); parents_map.erase(src_parents);
} }

View File

@ -795,3 +795,12 @@ TEST_CASE("test compile lambda") {
out = cfun2({array(0)}); out = cfun2({array(0)});
CHECK_EQ(out[0].item<int>(), 3); CHECK_EQ(out[0].item<int>(), 3);
} }
TEST_CASE("test compile with no-ops") {
auto fun = [](const std::vector<array>& inputs) {
return std::vector<array>{abs(stop_gradient(abs(inputs[0])))};
};
auto in = array(1.0);
auto out = compile(fun)({in})[0];
CHECK_EQ(out.inputs()[0].id(), in.id());
}