This commit is contained in:
Awni Hannun 2024-01-15 21:37:31 -08:00
parent b75ff47098
commit d34d10ebac

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <map> #include <map>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
@ -252,7 +251,6 @@ void compile_simplify(
for (auto& arr : tape) { for (auto& arr : tape) {
// Helper to check if we can fuse the parents of the // Helper to check if we can fuse the parents of the
// given array // given array
// If an array has no parents and siblings have
auto maybe_fuse_parents = [&](auto& a) { auto maybe_fuse_parents = [&](auto& a) {
auto parents = parents_map.find(a.id()); auto parents = parents_map.find(a.id());
if (parents != parents_map.end()) { if (parents != parents_map.end()) {
@ -276,7 +274,7 @@ void compile_simplify(
} }
return false; return false;
} else { } else {
return output_set.find(a.id()) != output_set.end(); return output_set.find(a.id()) == output_set.end();
} }
}; };
@ -349,8 +347,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
compile_dfs(entry.inputs, entry.outputs); compile_dfs(entry.inputs, entry.outputs);
// Simplify the tape // Simplify the tape
// compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 2);
// 2);
// This is a good point to do more optimizations, e.g. kernel fusion to // This is a good point to do more optimizations, e.g. kernel fusion to
// generate new primitives. The tape needs to be updated accordingly // generate new primitives. The tape needs to be updated accordingly