From ba8ce2359735470b3289ddf8ff05f92a8fdb64e2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 16 Jan 2024 12:53:57 -0800 Subject: [PATCH] simplify inputs also --- mlx/compile.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 124126640..0306fb229 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. - #include #include #include @@ -112,15 +111,16 @@ std::pair, ParentsMap> compile_dfs( const std::vector& outputs) { std::function recurse; std::vector tape; - std::unordered_set cache; + std::unordered_set input_set; std::unordered_map>> parents_map; for (int i = 0; i < inputs.size(); ++i) { auto in = inputs[i]; - cache.insert(in.id()); + input_set.insert(in.id()); } // DFS the graph to build the tape, and log parents and scalars + std::unordered_set cache; recurse = [&](const array& a) { auto id = a.id(); if (cache.find(id) != cache.end()) { @@ -132,7 +132,11 @@ std::pair, ParentsMap> compile_dfs( for (auto& s : a.siblings()) { parents_map[in.id()].push_back({s, i}); } - recurse(in); + // Don't recurse on inputs (but add them to the tape for the purpose + // of future optimizations) + if (input_set.find(a.id()) == input_set.end()) { + recurse(in); + } } cache.insert(id); for (auto& s : a.siblings()) { @@ -277,6 +281,12 @@ void compile_simplify( } } } + // Erase orphaned parents so we don't keep fusing with them + for (int i = N - 1; i > 0; --i) { + if (mask[i]) { + parents->second.erase(parents->second.begin() + i); + } + } return false; } else { return output_set.find(a.id()) == output_set.end(); @@ -352,7 +362,7 @@ std::function(const std::vector&)> compile( compile_dfs(entry.inputs, entry.outputs); // Simplify the tape - compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 2); + compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3); // This is a good point to do more optimizations, e.g. kernel fusion to // generate new primitives. The tape needs to be updated accordingly