From 93d76b0f301aa8087b3b5922226bbee4e01858de Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 3 Nov 2025 06:33:43 -0800 Subject: [PATCH] Fix compile multi capture (#2678) * fix compile when compiling multiple lambdas with the same capture * add test --- mlx/array.h | 5 ++ mlx/compile.cpp | 146 ++++++++++++++++++++++++++--------- mlx/compile_impl.h | 2 +- python/tests/test_compile.py | 24 ++++++ tests/compile_tests.cpp | 3 +- 5 files changed, 139 insertions(+), 41 deletions(-) diff --git a/mlx/array.h b/mlx/array.h index 4e9a5ae63..279d70a5e 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -294,6 +294,11 @@ class array { return array_desc_->siblings; } + /** The array's position in the sibling list. */ + int sibling_position() const { + return array_desc_->position; + } + void set_siblings(std::vector siblings, uint16_t position) { array_desc_->siblings = std::move(siblings); array_desc_->position = position; diff --git a/mlx/compile.cpp b/mlx/compile.cpp index d762c8d15..4649a3708 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -412,51 +412,121 @@ compile_trace( // Traverses the graph to build a tape and a map of array ids to their parents std::pair, ParentsMap> compile_dfs( const std::vector& inputs, - const std::vector& outputs, + std::vector& outputs, const std::vector& original_inputs) { - std::function recurse; std::vector tape; - std::unordered_set input_set; - std::unordered_set original_input_set; std::unordered_map>> parents_map; - for (int i = 0; i < inputs.size(); ++i) { - input_set.insert(inputs[i].id()); - original_input_set.insert(original_inputs[i].id()); + { + std::function recurse; + std::unordered_set input_set; + std::unordered_set original_input_set; + for (int i = 0; i < inputs.size(); ++i) { + input_set.insert(inputs[i].id()); + original_input_set.insert(original_inputs[i].id()); + } + + // DFS the graph to build the tape, and log parents and scalars + std::unordered_set cache; + recurse = [&](const array& a) { + auto id = a.id(); + if (original_input_set.find(id) != original_input_set.end()) { + throw std::invalid_argument( + "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); + } + if (cache.find(id) != cache.end()) { + return; + } + for (int i = 0; i < a.inputs().size(); i++) { + auto& in = a.inputs()[i]; + parents_map[in.id()].push_back({a, i}); + for (auto& s : a.siblings()) { + parents_map[in.id()].push_back({s, i}); + } + // Don't recurse on inputs (but add them to the tape for the purpose + // of future optimizations) + if (input_set.find(a.id()) == input_set.end()) { + recurse(in); + } + } + cache.insert(id); + for (auto& s : a.siblings()) { + cache.insert(s.id()); + } + tape.push_back(a); + }; + for (auto& a : outputs) { + recurse(a); + } } - // DFS the graph to build the tape, and log parents and scalars - std::unordered_set cache; - recurse = [&](const array& a) { - auto id = a.id(); - if (original_input_set.find(id) != original_input_set.end()) { - throw std::invalid_argument( - "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); - } - if (cache.find(id) != cache.end()) { - return; - } - for (int i = 0; i < a.inputs().size(); i++) { - auto& in = a.inputs()[i]; - parents_map[in.id()].push_back({a, i}); - for (auto& s : a.siblings()) { - parents_map[in.id()].push_back({s, i}); - } - // Don't recurse on inputs (but add them to the tape for the purpose - // of future optimizations) - if (input_set.find(a.id()) == input_set.end()) { - recurse(in); - } - } - cache.insert(id); - for (auto& s : a.siblings()) { - cache.insert(s.id()); - } - tape.push_back(a); - }; - for (auto& a : outputs) { - recurse(a); + // Deep copy the tape and parents map while preserving inputs and outputs + std::vector new_tape; + std::unordered_set io_set; + std::unordered_map old_to_new; + for (auto& o : outputs) { + old_to_new.insert({o.id(), o}); + io_set.insert(o.id()); } + for (auto& i : inputs) { + io_set.insert(i.id()); + old_to_new.insert({i.id(), i}); + } + + new_tape.reserve(tape.size()); + for (auto& arr : tape) { + if (!arr.has_primitive() || (io_set.find(arr.id()) != io_set.end())) { + old_to_new.insert({arr.id(), arr}); + new_tape.push_back(arr); + continue; + } + std::vector inputs; + inputs.reserve(arr.inputs().size()); + for (auto& i : arr.inputs()) { + inputs.push_back(old_to_new.find(i.id())->second); + } + if (arr.siblings().size() > 0) { + std::vector types; + std::vector shapes; + auto out = arr.outputs(); + for (auto& o : out) { + types.push_back(o.dtype()); + shapes.push_back(o.shape()); + } + auto as = array::make_arrays( + std::move(shapes), types, arr.primitive_ptr(), std::move(inputs)); + for (int i = 0; i < out.size(); ++i) { + old_to_new.insert({out[i].id(), as[i]}); + } + new_tape.push_back(as[arr.sibling_position()]); + } else { + auto a = array( + arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs)); + old_to_new.insert({arr.id(), a}); + new_tape.push_back(a); + } + } + io_set.clear(); + for (auto& o : outputs) { + if (!(io_set.insert(o.id()).second)) { + continue; + } + for (auto& i : o.inputs()) { + i = old_to_new.find(i.id())->second; + } + } + tape = std::move(new_tape); + + std::unordered_map>> + new_parents_map; + for (auto& [id, vec] : parents_map) { + for (auto& [a, _] : vec) { + a = old_to_new.find(a.id())->second; + } + new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec); + } + parents_map = std::move(new_parents_map); + return {tape, parents_map}; } diff --git a/mlx/compile_impl.h b/mlx/compile_impl.h index 1b48c88aa..ae8e26b92 100644 --- a/mlx/compile_impl.h +++ b/mlx/compile_impl.h @@ -47,7 +47,7 @@ using ParentsMap = // Traverses the graph to build a tape and a map of array ids to their parents std::pair, ParentsMap> compile_dfs( const std::vector& inputs, - const std::vector& outputs, + std::vector& outputs, const std::vector& original_inputs); // Simplify the tape. diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index d8900d16d..26132b628 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1134,6 +1134,30 @@ class TestCompile(mlx_tests.MLXTestCase): a = fun2(mx.array(-1.0)) self.assertEqual(a.item(), 1.0) + def test_multiple_compile_same_capture(self): + def fun(do_compile): + t = mx.ones((10,)) + u = (1.0 - t) * 0.0 + t * 3.0 + + o = mx.ones((6,)) + b = o[:, None] * u + + c = b * mx.ones_like(u) + + a = mx.ones((6,)) + if do_compile: + d = mx.compile(lambda x: x @ b)(a) + e = mx.compile(lambda x: x @ c.T)(d) + else: + d = a @ b + e = d @ c.T + return e + + out = fun(True) + mx.eval(out) + expected = fun(False) + self.assertTrue(mx.allclose(out, expected)) + if __name__ == "__main__": mlx_tests.MLXTestRunner() diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 96552ef9d..e65cfc76f 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -194,8 +194,7 @@ auto multi_one(const std::vector&) { auto multi_two(const std::vector&) { auto a = array(1.0); auto b = array(1.0); - auto c = divmod(a, b); - return std::vector{c}; + return divmod(a, b); } auto multi_three(const std::vector&) {