From c473719b23d9091292b10f0cd4f3636300831a58 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 15 Oct 2025 17:08:22 -0700 Subject: [PATCH] fix compile when compiling multiple lambdas with the same capture --- mlx/compile.cpp | 67 ++++++++++++++++++++++++++++++++++++++++- mlx/compile_impl.h | 2 +- tests/compile_tests.cpp | 3 +- 3 files changed, 68 insertions(+), 4 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 90791b02e..4ca45a56c 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -412,7 +412,7 @@ compile_trace( // Traverses the graph to build a tape and a map of array ids to their parents std::pair, ParentsMap> compile_dfs( const std::vector& inputs, - const std::vector& outputs, + std::vector& outputs, const std::vector& original_inputs) { std::function recurse; std::vector tape; @@ -457,6 +457,71 @@ std::pair, ParentsMap> compile_dfs( for (auto& a : outputs) { recurse(a); } + + // Deep copy the tape and parents map + std::vector new_tape; + std::unordered_set io_set; + std::unordered_map old_to_new; + for (auto& o : outputs) { + io_set.insert(o.id()); + } + for (auto& i : inputs) { + io_set.insert(i.id()); + old_to_new.insert({i.id(), i}); + } + + new_tape.reserve(tape.size()); + for (auto& arr : tape) { + if (!arr.has_primitive() || (io_set.find(arr.id()) != io_set.end())) { + old_to_new.insert({arr.id(), arr}); + new_tape.push_back(arr); + continue; + } + std::vector inputs; + inputs.reserve(arr.inputs().size()); + for (auto& i : arr.inputs()) { + inputs.push_back(old_to_new.find(i.id())->second); + } + if (arr.siblings().size() > 0) { + // use make_arrays + std::vector types; + std::vector shapes; + auto out = arr.outputs(); + for (auto& o : out) { + types.push_back(o.dtype()); + shapes.push_back(o.shape()); + } + auto as = array::make_arrays( + std::move(shapes), types, arr.primitive_ptr(), std::move(inputs)); + for (int i = 0; i < out.size(); ++i) { + old_to_new.insert({out[i].id(), as[i]}); + } + // TODO maybe need to preserve position of sibling that is in tape + new_tape.push_back(as[0]); + } else { + auto a = array( + arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs)); + old_to_new.insert({arr.id(), a}); + new_tape.push_back(a); + } + } + for (auto& o : outputs) { + for (auto& i : o.inputs()) { + i = old_to_new.find(i.id())->second; + } + } + tape = std::move(new_tape); + + std::unordered_map>> + new_parents_map; + for (auto& [id, vec] : parents_map) { + for (auto& [a, _] : vec) { + a = old_to_new.find(a.id())->second; + } + new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec); + } + parents_map = std::move(new_parents_map); + return {tape, parents_map}; } diff --git a/mlx/compile_impl.h b/mlx/compile_impl.h index 1b48c88aa..ae8e26b92 100644 --- a/mlx/compile_impl.h +++ b/mlx/compile_impl.h @@ -47,7 +47,7 @@ using ParentsMap = // Traverses the graph to build a tape and a map of array ids to their parents std::pair, ParentsMap> compile_dfs( const std::vector& inputs, - const std::vector& outputs, + std::vector& outputs, const std::vector& original_inputs); // Simplify the tape. diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 96552ef9d..e65cfc76f 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -194,8 +194,7 @@ auto multi_one(const std::vector&) { auto multi_two(const std::vector&) { auto a = array(1.0); auto b = array(1.0); - auto c = divmod(a, b); - return std::vector{c}; + return divmod(a, b); } auto multi_three(const std::vector&) {