mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
fix compile merging (#2150)
This commit is contained in:
parent
481349495b
commit
9c5e7da507
@ -168,6 +168,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
|||||||
parent.first.inputs()[parent.second] = dst;
|
parent.first.inputs()[parent.second] = dst;
|
||||||
pairs.push_back(parent);
|
pairs.push_back(parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If src is a parent of dst, remove it from dst's parents
|
||||||
|
for (auto it = pairs.begin(); it != pairs.end();) {
|
||||||
|
if (it->first.id() == src.id()) {
|
||||||
|
it = pairs.erase(it);
|
||||||
|
} else {
|
||||||
|
it++;
|
||||||
|
}
|
||||||
|
}
|
||||||
// Remove the source from the map to avoid fusing with it again
|
// Remove the source from the map to avoid fusing with it again
|
||||||
parents_map.erase(src_parents);
|
parents_map.erase(src_parents);
|
||||||
}
|
}
|
||||||
|
@ -795,3 +795,12 @@ TEST_CASE("test compile lambda") {
|
|||||||
out = cfun2({array(0)});
|
out = cfun2({array(0)});
|
||||||
CHECK_EQ(out[0].item<int>(), 3);
|
CHECK_EQ(out[0].item<int>(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile with no-ops") {
|
||||||
|
auto fun = [](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{abs(stop_gradient(abs(inputs[0])))};
|
||||||
|
};
|
||||||
|
auto in = array(1.0);
|
||||||
|
auto out = compile(fun)({in})[0];
|
||||||
|
CHECK_EQ(out.inputs()[0].id(), in.id());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user