diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 7ed3dd455..90791b02e 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -460,6 +460,24 @@ std::pair, ParentsMap> compile_dfs( return {tape, parents_map}; } /* splitmix64: 64-bit integer mixer. Adds the constant 0x9e3779b97f4a7c15 to x, then applies two xor-shift-multiply rounds (shifts of 30 and 27) and a final xor with x >> 31. All arithmetic is on uint64_t, so it wraps modulo 2^64. Takes and returns a uint64_t and is declared noexcept. */+static inline uint64_t splitmix64(uint64_t x) noexcept { + x += 0x9e3779b97f4a7c15ull; + x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ull; + x = (x ^ (x >> 27)) * 0x94d049bb133111ebull; + return x ^ (x >> 31); +} + /* VecU64Hash: hash functor over a std::vector key; it is supplied as the hasher of dst_map in compile_simplify below. NOTE(review): the vector's element type is not visible in this view (template arguments appear stripped); the loop reads each element as uint64_t -- confirm against the real file. */+struct VecU64Hash { + size_t operator()(const std::vector& s) const noexcept { /* The initial state mixes in the element count. Multiplication binds tighter than xor, so this evaluates as constant ^ (size * 0x9e3779b97f4a7c15). */+ uint64_t h = + 0x243f6a8885a308d3ull ^ (uint64_t)s.size() * 0x9e3779b97f4a7c15ull; /* Fold every element into h through splitmix64. Each step consumes the previous h, so the result depends on element order. */+ for (uint64_t x : s) { + h = splitmix64(x ^ splitmix64(h + 0x9e3779b97f4a7c15ull)); + } + return (size_t)h; + } +}; + // Simplify the tape. Note, this function modifies in-place both the tape, // the parents map to remove orphaned arrays, and potentially the outputs void compile_simplify( @@ -584,29 +602,73 @@ void compile_simplify( if (parents != parents_map.end()) { auto N = parents->second.size(); std::vector mask(N, false); - for (int i = 0; i < N; i++) { - if (mask[i]) { - continue; + + auto try_merge = [&](int dst_idx, int src_idx) { + if (tape_order[parents->second[src_idx].first.id()] < + tape_order[parents->second[dst_idx].first.id()]) { + std::swap(src_idx, dst_idx); } - for (int j = i + 1; j < N; j++) { - if (mask[j]) { + auto& src = parents->second[src_idx].first; + auto& dst = parents->second[dst_idx].first; + if (src.id() != dst.id() && array_equivalent(src, dst) && + output_set.find(src.id()) == output_set.end()) { + merge(dst, src, parents_map); + mask[src_idx] = true; + } + }; + + if (N > 100) { + std::unordered_map< + std::vector, + std::vector, + VecU64Hash> + dst_map; + // Find possibly mergeable groups + for (int i = 0; i < N; i++) { + // Make the hash key + std::vector key; + auto& curr = parents->second[i].first; + key.reserve(curr.inputs().size() + 2); + for (auto& in : curr.inputs()) { + key.push_back(in.id()); + } + auto& p = curr.primitive(); 
key.push_back(curr.inputs().size()); + key.push_back(typeid(p).hash_code()); + auto it = dst_map.find(key); + if (it == dst_map.end()) { + bool _; + std::tie(it, _) = dst_map.insert({key, std::vector{}}); + } + it->second.push_back(i); + } + for (auto& [_, group] : dst_map) { + for (int i = 0; i < group.size(); ++i) { + if (mask[group[i]]) { + continue; + } + for (int j = i + 1; j < group.size(); ++j) { + if (mask[group[j]]) { + continue; + } + try_merge(group[i], group[j]); + } + } + } + } else { + for (int i = 0; i < N; ++i) { + if (mask[i]) { continue; } - auto src_idx = j; - auto dst_idx = i; - if (tape_order[parents->second[src_idx].first.id()] < - tape_order[parents->second[dst_idx].first.id()]) { - std::swap(src_idx, dst_idx); - } - auto& src = parents->second[src_idx].first; - auto& dst = parents->second[dst_idx].first; - if (src.id() != dst.id() && array_equivalent(src, dst) && - output_set.find(src.id()) == output_set.end()) { - merge(dst, src, parents_map); - mask[src_idx] = true; + for (int j = i + 1; j < N; ++j) { + if (mask[j]) { + continue; + } + try_merge(i, j); } } } + // Erase orphaned parents so we don't keep fusing with them for (int i = N - 1; i >= 0; --i) { if (mask[i]) {