simplify inputs also

This commit is contained in:
Awni Hannun 2024-01-16 12:53:57 -08:00
parent df1f6c221b
commit ba8ce23597

View File

@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <map>
#include <unordered_map>
#include <unordered_set>
@ -112,15 +111,16 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
const std::vector<array>& outputs) {
std::function<void(const array&)> recurse;
std::vector<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_set<std::uintptr_t> input_set;
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
parents_map;
for (int i = 0; i < inputs.size(); ++i) {
auto in = inputs[i];
cache.insert(in.id());
input_set.insert(in.id());
}
// DFS the graph to build the tape, and log parents and scalars
std::unordered_set<std::uintptr_t> cache;
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
@ -132,7 +132,11 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
for (auto& s : a.siblings()) {
parents_map[in.id()].push_back({s, i});
}
recurse(in);
// Don't recurse on inputs (but add them to the tape for the purpose
// of future optimizations)
if (input_set.find(a.id()) == input_set.end()) {
recurse(in);
}
}
cache.insert(id);
for (auto& s : a.siblings()) {
@ -277,6 +281,12 @@ void compile_simplify(
}
}
}
// Erase orphaned parents so we don't keep fusing with them
for (int i = N - 1; i > 0; --i) {
if (mask[i]) {
parents->second.erase(parents->second.begin() + i);
}
}
return false;
} else {
return output_set.find(a.id()) == output_set.end();
@ -352,7 +362,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
compile_dfs(entry.inputs, entry.outputs);
// Simplify the tape
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 2);
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
// This is a good point to do more optimizations, e.g. kernel fusion to
// generate new primitives. The tape needs to be updated accordingly