mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-12 07:18:52 +08:00
Speed up compile for node with many parents (#2649)
This commit is contained in:
@@ -460,6 +460,24 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
return {tape, parents_map};
|
return {tape, parents_map};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SplitMix64-style 64-bit mixer: offsets the input by a fixed constant and
// then applies two xor-shift/multiply rounds followed by a final xor-shift.
// All arithmetic is on uint64_t, so overflow wraps modulo 2^64 by definition.
// Pure function: the result depends only on `x`.
static inline uint64_t splitmix64(uint64_t x) noexcept {
  x += 0x9e3779b97f4a7c15ull;
  x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ull;
  x = (x ^ (x >> 27)) * 0x94d049bb133111ebull;
  return x ^ (x >> 31);
}
|
||||||
|
|
||||||
|
struct VecU64Hash {
|
||||||
|
size_t operator()(const std::vector<uint64_t>& s) const noexcept {
|
||||||
|
uint64_t h =
|
||||||
|
0x243f6a8885a308d3ull ^ (uint64_t)s.size() * 0x9e3779b97f4a7c15ull;
|
||||||
|
for (uint64_t x : s) {
|
||||||
|
h = splitmix64(x ^ splitmix64(h + 0x9e3779b97f4a7c15ull));
|
||||||
|
}
|
||||||
|
return (size_t)h;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Simplify the tape. Note, this function modifies in-place both the tape,
|
// Simplify the tape. Note, this function modifies in-place both the tape,
|
||||||
// the parents map to remove orphaned arrays, and potentially the outputs
|
// the parents map to remove orphaned arrays, and potentially the outputs
|
||||||
void compile_simplify(
|
void compile_simplify(
|
||||||
@@ -584,29 +602,73 @@ void compile_simplify(
|
|||||||
if (parents != parents_map.end()) {
|
if (parents != parents_map.end()) {
|
||||||
auto N = parents->second.size();
|
auto N = parents->second.size();
|
||||||
std::vector<bool> mask(N, false);
|
std::vector<bool> mask(N, false);
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
if (mask[i]) {
|
auto try_merge = [&](int dst_idx, int src_idx) {
|
||||||
continue;
|
if (tape_order[parents->second[src_idx].first.id()] <
|
||||||
|
tape_order[parents->second[dst_idx].first.id()]) {
|
||||||
|
std::swap(src_idx, dst_idx);
|
||||||
}
|
}
|
||||||
for (int j = i + 1; j < N; j++) {
|
auto& src = parents->second[src_idx].first;
|
||||||
if (mask[j]) {
|
auto& dst = parents->second[dst_idx].first;
|
||||||
|
if (src.id() != dst.id() && array_equivalent(src, dst) &&
|
||||||
|
output_set.find(src.id()) == output_set.end()) {
|
||||||
|
merge(dst, src, parents_map);
|
||||||
|
mask[src_idx] = true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (N > 100) {
|
||||||
|
std::unordered_map<
|
||||||
|
std::vector<uint64_t>,
|
||||||
|
std::vector<int>,
|
||||||
|
VecU64Hash>
|
||||||
|
dst_map;
|
||||||
|
// Find possibly mergeable groups
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
// Make the hash key
|
||||||
|
std::vector<uint64_t> key;
|
||||||
|
auto& curr = parents->second[i].first;
|
||||||
|
key.reserve(curr.inputs().size() + 2);
|
||||||
|
for (auto& in : curr.inputs()) {
|
||||||
|
key.push_back(in.id());
|
||||||
|
}
|
||||||
|
auto& p = curr.primitive();
|
||||||
|
key.push_back(curr.inputs().size());
|
||||||
|
key.push_back(typeid(p).hash_code());
|
||||||
|
auto it = dst_map.find(key);
|
||||||
|
if (it == dst_map.end()) {
|
||||||
|
bool _;
|
||||||
|
std::tie(it, _) = dst_map.insert({key, std::vector<int>{}});
|
||||||
|
}
|
||||||
|
it->second.push_back(i);
|
||||||
|
}
|
||||||
|
for (auto& [_, group] : dst_map) {
|
||||||
|
for (int i = 0; i < group.size(); ++i) {
|
||||||
|
if (mask[group[i]]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (int j = i + 1; j < group.size(); ++j) {
|
||||||
|
if (mask[group[j]]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
try_merge(group[i], group[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
if (mask[i]) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto src_idx = j;
|
for (int j = i + 1; j < N; ++j) {
|
||||||
auto dst_idx = i;
|
if (mask[j]) {
|
||||||
if (tape_order[parents->second[src_idx].first.id()] <
|
continue;
|
||||||
tape_order[parents->second[dst_idx].first.id()]) {
|
}
|
||||||
std::swap(src_idx, dst_idx);
|
try_merge(i, j);
|
||||||
}
|
|
||||||
auto& src = parents->second[src_idx].first;
|
|
||||||
auto& dst = parents->second[dst_idx].first;
|
|
||||||
if (src.id() != dst.id() && array_equivalent(src, dst) &&
|
|
||||||
output_set.find(src.id()) == output_set.end()) {
|
|
||||||
merge(dst, src, parents_map);
|
|
||||||
mask[src_idx] = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Erase orphaned parents so we don't keep fusing with them
|
// Erase orphaned parents so we don't keep fusing with them
|
||||||
for (int i = N - 1; i >= 0; --i) {
|
for (int i = N - 1; i >= 0; --i) {
|
||||||
if (mask[i]) {
|
if (mask[i]) {
|
||||||
|
|||||||
Reference in New Issue
Block a user