diff --git a/mlx/compile.cpp b/mlx/compile.cpp index ab3be9b070..17aa6a5237 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include #include #include @@ -252,7 +251,6 @@ void compile_simplify( for (auto& arr : tape) { // Helper to check if we can fuse the parents of the // given array - // If an array has no parents and siblings have auto maybe_fuse_parents = [&](auto& a) { auto parents = parents_map.find(a.id()); if (parents != parents_map.end()) { @@ -276,7 +274,7 @@ void compile_simplify( } return false; } else { - return output_set.find(a.id()) != output_set.end(); + return output_set.find(a.id()) == output_set.end(); } }; @@ -349,8 +347,7 @@ std::function(const std::vector&)> compile( compile_dfs(entry.inputs, entry.outputs); // Simplify the tape - // compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ - // 2); + compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 2); // This is a good point to do more optimizations, e.g. kernel fusion to // generate new primitives. The tape needs to be updated accordingly