Speed up compile for node with many parents (#2649)

This commit is contained in:
Awni Hannun
2025-10-03 19:30:36 -07:00
committed by GitHub
parent a7a94b29d7
commit a393435d28

View File

@@ -460,6 +460,24 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
return {tape, parents_map};
}
// Scrambles a 64-bit value with the SplitMix64 mixing sequence: add a fixed
// increment, then run two xor-shift/multiply rounds followed by a final
// xor-shift. All arithmetic is unsigned and therefore wraps modulo 2^64.
static inline uint64_t splitmix64(uint64_t x) noexcept {
  constexpr uint64_t kIncrement = 0x9e3779b97f4a7c15ull;
  constexpr uint64_t kFirstMultiplier = 0xbf58476d1ce4e5b9ull;
  constexpr uint64_t kSecondMultiplier = 0x94d049bb133111ebull;

  uint64_t mixed = x + kIncrement;
  // Round 1: fold the high bits down, then multiply.
  mixed ^= mixed >> 30;
  mixed *= kFirstMultiplier;
  // Round 2: same shape with a different shift and multiplier.
  mixed ^= mixed >> 27;
  mixed *= kSecondMultiplier;
  // Final fold so the top bits also influence the low bits of the result.
  mixed ^= mixed >> 31;
  return mixed;
}
// Hash functor for std::vector<uint64_t>, usable as the hasher of an
// unordered container keyed by such vectors.
//
// The running state is seeded from the vector's length, then every element
// is folded in, in order, through two nested splitmix64 calls. Both the
// length and the element order therefore feed into the result.
struct VecU64Hash {
  size_t operator()(const std::vector<uint64_t>& s) const noexcept {
    constexpr uint64_t kSeed = 0x243f6a8885a308d3ull;
    constexpr uint64_t kStep = 0x9e3779b97f4a7c15ull;

    // Multiplication binds tighter than XOR; the parentheses only make the
    // existing evaluation order explicit.
    uint64_t state = kSeed ^ (static_cast<uint64_t>(s.size()) * kStep);
    for (const uint64_t element : s) {
      // Scramble the current state first, then mix the element into it.
      const uint64_t scrambled_state = splitmix64(state + kStep);
      state = splitmix64(element ^ scrambled_state);
    }
    // Narrowing to size_t keeps the low bits where size_t is under 64 bits.
    return static_cast<size_t>(state);
  }
};
// Simplify the tape. Note that this function modifies, in place, the tape,
// the parents map (to remove orphaned arrays), and potentially the outputs.
void compile_simplify(
@@ -584,29 +602,73 @@ void compile_simplify(
if (parents != parents_map.end()) {
auto N = parents->second.size();
std::vector<bool> mask(N, false);
for (int i = 0; i < N; i++) {
if (mask[i]) {
continue;
auto try_merge = [&](int dst_idx, int src_idx) {
if (tape_order[parents->second[src_idx].first.id()] <
tape_order[parents->second[dst_idx].first.id()]) {
std::swap(src_idx, dst_idx);
}
for (int j = i + 1; j < N; j++) {
if (mask[j]) {
auto& src = parents->second[src_idx].first;
auto& dst = parents->second[dst_idx].first;
if (src.id() != dst.id() && array_equivalent(src, dst) &&
output_set.find(src.id()) == output_set.end()) {
merge(dst, src, parents_map);
mask[src_idx] = true;
}
};
if (N > 100) {
std::unordered_map<
std::vector<uint64_t>,
std::vector<int>,
VecU64Hash>
dst_map;
// Find possibly mergeable groups
for (int i = 0; i < N; i++) {
// Make the hash key
std::vector<uint64_t> key;
auto& curr = parents->second[i].first;
key.reserve(curr.inputs().size() + 2);
for (auto& in : curr.inputs()) {
key.push_back(in.id());
}
auto& p = curr.primitive();
key.push_back(curr.inputs().size());
key.push_back(typeid(p).hash_code());
auto it = dst_map.find(key);
if (it == dst_map.end()) {
bool _;
std::tie(it, _) = dst_map.insert({key, std::vector<int>{}});
}
it->second.push_back(i);
}
for (auto& [_, group] : dst_map) {
for (int i = 0; i < group.size(); ++i) {
if (mask[group[i]]) {
continue;
}
for (int j = i + 1; j < group.size(); ++j) {
if (mask[group[j]]) {
continue;
}
try_merge(group[i], group[j]);
}
}
}
} else {
for (int i = 0; i < N; ++i) {
if (mask[i]) {
continue;
}
auto src_idx = j;
auto dst_idx = i;
if (tape_order[parents->second[src_idx].first.id()] <
tape_order[parents->second[dst_idx].first.id()]) {
std::swap(src_idx, dst_idx);
}
auto& src = parents->second[src_idx].first;
auto& dst = parents->second[dst_idx].first;
if (src.id() != dst.id() && array_equivalent(src, dst) &&
output_set.find(src.id()) == output_set.end()) {
merge(dst, src, parents_map);
mask[src_idx] = true;
for (int j = i + 1; j < N; ++j) {
if (mask[j]) {
continue;
}
try_merge(i, j);
}
}
}
// Erase orphaned parents so we don't keep fusing with them
for (int i = N - 1; i >= 0; --i) {
if (mask[i]) {