mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-10 05:59:04 +08:00
Speed up compile for node with many parents (#2649)
This commit is contained in:
@@ -460,6 +460,24 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
return {tape, parents_map};
|
||||
}
|
||||
|
||||
// SplitMix64-style 64-bit mixing step.
//
// Adds a fixed odd constant to `x`, then applies two xor-shift/multiply
// rounds followed by a final xor-shift so that every input bit influences
// the result. All arithmetic is unsigned and wraps modulo 2^64.
//
// The function is pure and `constexpr`, so it can be evaluated at compile
// time as well as at run time.
static inline constexpr uint64_t splitmix64(uint64_t x) noexcept {
  x += 0x9e3779b97f4a7c15ull;
  x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ull;
  x = (x ^ (x >> 27)) * 0x94d049bb133111ebull;
  return x ^ (x >> 31);
}
|
||||
|
||||
struct VecU64Hash {
|
||||
size_t operator()(const std::vector<uint64_t>& s) const noexcept {
|
||||
uint64_t h =
|
||||
0x243f6a8885a308d3ull ^ (uint64_t)s.size() * 0x9e3779b97f4a7c15ull;
|
||||
for (uint64_t x : s) {
|
||||
h = splitmix64(x ^ splitmix64(h + 0x9e3779b97f4a7c15ull));
|
||||
}
|
||||
return (size_t)h;
|
||||
}
|
||||
};
|
||||
|
||||
// Simplify the tape. Note, this function modifies in-place both the tape,
|
||||
// the parents map to remove orphaned arrays, and potentially the outputs
|
||||
void compile_simplify(
|
||||
@@ -584,29 +602,73 @@ void compile_simplify(
|
||||
if (parents != parents_map.end()) {
|
||||
auto N = parents->second.size();
|
||||
std::vector<bool> mask(N, false);
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (mask[i]) {
|
||||
continue;
|
||||
|
||||
auto try_merge = [&](int dst_idx, int src_idx) {
|
||||
if (tape_order[parents->second[src_idx].first.id()] <
|
||||
tape_order[parents->second[dst_idx].first.id()]) {
|
||||
std::swap(src_idx, dst_idx);
|
||||
}
|
||||
for (int j = i + 1; j < N; j++) {
|
||||
if (mask[j]) {
|
||||
auto& src = parents->second[src_idx].first;
|
||||
auto& dst = parents->second[dst_idx].first;
|
||||
if (src.id() != dst.id() && array_equivalent(src, dst) &&
|
||||
output_set.find(src.id()) == output_set.end()) {
|
||||
merge(dst, src, parents_map);
|
||||
mask[src_idx] = true;
|
||||
}
|
||||
};
|
||||
|
||||
if (N > 100) {
|
||||
std::unordered_map<
|
||||
std::vector<uint64_t>,
|
||||
std::vector<int>,
|
||||
VecU64Hash>
|
||||
dst_map;
|
||||
// Find possibly mergeable groups
|
||||
for (int i = 0; i < N; i++) {
|
||||
// Make the hash key
|
||||
std::vector<uint64_t> key;
|
||||
auto& curr = parents->second[i].first;
|
||||
key.reserve(curr.inputs().size() + 2);
|
||||
for (auto& in : curr.inputs()) {
|
||||
key.push_back(in.id());
|
||||
}
|
||||
auto& p = curr.primitive();
|
||||
key.push_back(curr.inputs().size());
|
||||
key.push_back(typeid(p).hash_code());
|
||||
auto it = dst_map.find(key);
|
||||
if (it == dst_map.end()) {
|
||||
bool _;
|
||||
std::tie(it, _) = dst_map.insert({key, std::vector<int>{}});
|
||||
}
|
||||
it->second.push_back(i);
|
||||
}
|
||||
for (auto& [_, group] : dst_map) {
|
||||
for (int i = 0; i < group.size(); ++i) {
|
||||
if (mask[group[i]]) {
|
||||
continue;
|
||||
}
|
||||
for (int j = i + 1; j < group.size(); ++j) {
|
||||
if (mask[group[j]]) {
|
||||
continue;
|
||||
}
|
||||
try_merge(group[i], group[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
if (mask[i]) {
|
||||
continue;
|
||||
}
|
||||
auto src_idx = j;
|
||||
auto dst_idx = i;
|
||||
if (tape_order[parents->second[src_idx].first.id()] <
|
||||
tape_order[parents->second[dst_idx].first.id()]) {
|
||||
std::swap(src_idx, dst_idx);
|
||||
}
|
||||
auto& src = parents->second[src_idx].first;
|
||||
auto& dst = parents->second[dst_idx].first;
|
||||
if (src.id() != dst.id() && array_equivalent(src, dst) &&
|
||||
output_set.find(src.id()) == output_set.end()) {
|
||||
merge(dst, src, parents_map);
|
||||
mask[src_idx] = true;
|
||||
for (int j = i + 1; j < N; ++j) {
|
||||
if (mask[j]) {
|
||||
continue;
|
||||
}
|
||||
try_merge(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Erase orphaned parents so we don't keep fusing with them
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
if (mask[i]) {
|
||||
|
||||
Reference in New Issue
Block a user