mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compile stride bug (#812)
* fix compile stride bug * revert sdpa fix * fix cpu * fix bug with simplifying outputs
This commit is contained in:
@@ -439,7 +439,8 @@ void compile_simplify(
|
||||
}
|
||||
auto& src = parents->second[j].first;
|
||||
auto& dst = parents->second[i].first;
|
||||
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
||||
if (src.id() != dst.id() && array_equivalent(src, dst) &&
|
||||
output_set.find(src.id()) == output_set.end()) {
|
||||
merge(dst, src, parents_map);
|
||||
mask[j] = true;
|
||||
}
|
||||
@@ -456,7 +457,6 @@ void compile_simplify(
|
||||
return output_set.find(a.id()) == output_set.end();
|
||||
}
|
||||
};
|
||||
|
||||
bool discard = maybe_merge_parents(arr);
|
||||
for (auto& s : arr.siblings()) {
|
||||
discard &= maybe_merge_parents(s);
|
||||
|
||||
Reference in New Issue
Block a user