Compile stride bug (#812)

* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
This commit is contained in:
Awni Hannun
2024-03-11 06:31:31 -07:00
committed by GitHub
parent a4d290adb9
commit 7c441600fe
9 changed files with 58 additions and 12 deletions

View File

@@ -439,7 +439,8 @@ void compile_simplify(
}
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
if (src.id() != dst.id() && array_equivalent(src, dst) &&
output_set.find(src.id()) == output_set.end()) {
merge(dst, src, parents_map);
mask[j] = true;
}
@@ -456,7 +457,6 @@ void compile_simplify(
return output_set.find(a.id()) == output_set.end();
}
};
bool discard = maybe_merge_parents(arr);
for (auto& s : arr.siblings()) {
discard &= maybe_merge_parents(s);