mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Fix bug in tape order during simplify (#816)
* fix bug in tape order during simplify * properly fix compile * last bug
This commit is contained in:
@@ -415,6 +415,11 @@ void compile_simplify(
|
||||
}
|
||||
tape = std::move(new_tape);
|
||||
|
||||
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
|
||||
for (uint32_t i = 0; i < tape.size(); ++i) {
|
||||
tape_order.insert({tape[i].id(), i});
|
||||
}
|
||||
|
||||
std::unordered_set<uintptr_t> output_set;
|
||||
for (auto& o : outputs) {
|
||||
output_set.insert(o.id());
|
||||
@@ -437,17 +442,23 @@ void compile_simplify(
|
||||
if (mask[j]) {
|
||||
continue;
|
||||
}
|
||||
auto& src = parents->second[j].first;
|
||||
auto& dst = parents->second[i].first;
|
||||
auto src_idx = j;
|
||||
auto dst_idx = i;
|
||||
if (tape_order[parents->second[src_idx].first.id()] <
|
||||
tape_order[parents->second[dst_idx].first.id()]) {
|
||||
std::swap(src_idx, dst_idx);
|
||||
}
|
||||
auto& src = parents->second[src_idx].first;
|
||||
auto& dst = parents->second[dst_idx].first;
|
||||
if (src.id() != dst.id() && array_equivalent(src, dst) &&
|
||||
output_set.find(src.id()) == output_set.end()) {
|
||||
merge(dst, src, parents_map);
|
||||
mask[j] = true;
|
||||
mask[src_idx] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Erase orphaned parents so we don't keep fusing with them
|
||||
for (int i = N - 1; i > 0; --i) {
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
if (mask[i]) {
|
||||
parents->second.erase(parents->second.begin() + i);
|
||||
}
|
||||
|
Reference in New Issue
Block a user