diff --git a/mlx/array.h b/mlx/array.h index 4e9a5ae63..279d70a5e 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -294,6 +294,11 @@ class array { return array_desc_->siblings; } + /** The array's position in the sibling list. */ + int sibling_position() const { + return array_desc_->position; + } + void set_siblings(std::vector siblings, uint16_t position) { array_desc_->siblings = std::move(siblings); array_desc_->position = position; diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 4ca45a56c..ac483f7b2 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -414,55 +414,58 @@ std::pair, ParentsMap> compile_dfs( const std::vector& inputs, std::vector& outputs, const std::vector& original_inputs) { - std::function recurse; std::vector tape; - std::unordered_set input_set; - std::unordered_set original_input_set; std::unordered_map>> parents_map; - for (int i = 0; i < inputs.size(); ++i) { - input_set.insert(inputs[i].id()); - original_input_set.insert(original_inputs[i].id()); - } + { + std::function recurse; + std::unordered_set input_set; + std::unordered_set original_input_set; + for (int i = 0; i < inputs.size(); ++i) { + input_set.insert(inputs[i].id()); + original_input_set.insert(original_inputs[i].id()); + } - // DFS the graph to build the tape, and log parents and scalars - std::unordered_set cache; - recurse = [&](const array& a) { - auto id = a.id(); - if (original_input_set.find(id) != original_input_set.end()) { - throw std::invalid_argument( - "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); - } - if (cache.find(id) != cache.end()) { - return; - } - for (int i = 0; i < a.inputs().size(); i++) { - auto& in = a.inputs()[i]; - parents_map[in.id()].push_back({a, i}); + // DFS the graph to build the tape, and log parents and scalars + std::unordered_set cache; + recurse = [&](const array& a) { + auto id = a.id(); + if (original_input_set.find(id) != original_input_set.end()) { + throw std::invalid_argument( + "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); + } + if (cache.find(id) != cache.end()) { + return; + } + for (int i = 0; i < a.inputs().size(); i++) { + auto& in = a.inputs()[i]; + parents_map[in.id()].push_back({a, i}); + for (auto& s : a.siblings()) { + parents_map[in.id()].push_back({s, i}); + } + // Don't recurse on inputs (but add them to the tape for the purpose + // of future optimizations) + if (input_set.find(a.id()) == input_set.end()) { + recurse(in); + } + } + cache.insert(id); for (auto& s : a.siblings()) { - parents_map[in.id()].push_back({s, i}); - } - // Don't recurse on inputs (but add them to the tape for the purpose - // of future optimizations) - if (input_set.find(a.id()) == input_set.end()) { - recurse(in); + cache.insert(s.id()); } + tape.push_back(a); + }; + for (auto& a : outputs) { + recurse(a); } - cache.insert(id); - for (auto& s : a.siblings()) { - cache.insert(s.id()); - } - tape.push_back(a); - }; - for (auto& a : outputs) { - recurse(a); } - // Deep copy the tape and parents map + // Deep copy the tape and parents map while preserving inputs and outputs std::vector new_tape; std::unordered_set io_set; std::unordered_map old_to_new; for (auto& o : outputs) { + old_to_new.insert({o.id(), o}); io_set.insert(o.id()); } for (auto& i : inputs) { @@ -483,7 +486,6 @@ std::pair, ParentsMap> compile_dfs( inputs.push_back(old_to_new.find(i.id())->second); } if (arr.siblings().size() > 0) { - // use make_arrays std::vector types; std::vector shapes; auto out = arr.outputs(); @@ -496,8 +498,7 @@ std::pair, ParentsMap> compile_dfs( for (int i = 0; i < out.size(); ++i) { old_to_new.insert({out[i].id(), as[i]}); } - // TODO maybe need to preserve position of sibling that is in tape - new_tape.push_back(as[0]); + new_tape.push_back(as[arr.sibling_position()]); } else { auto a = array( arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs)); @@ -505,7 +506,11 @@ std::pair, ParentsMap> compile_dfs( new_tape.push_back(a); } } + io_set.clear(); for (auto& o : outputs) { + if (!(io_set.insert(o.id()).second)) { + continue; + } for (auto& i : o.inputs()) { i = old_to_new.find(i.id())->second; } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index d8900d16d..26132b628 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1134,6 +1134,30 @@ class TestCompile(mlx_tests.MLXTestCase): a = fun2(mx.array(-1.0)) self.assertEqual(a.item(), 1.0) + def test_multiple_compile_same_capture(self): + def fun(do_compile): + t = mx.ones((10,)) + u = (1.0 - t) * 0.0 + t * 3.0 + + o = mx.ones((6,)) + b = o[:, None] * u + + c = b * mx.ones_like(u) + + a = mx.ones((6,)) + if do_compile: + d = mx.compile(lambda x: x @ b)(a) + e = mx.compile(lambda x: x @ c.T)(d) + else: + d = a @ b + e = d @ c.T + return e + + out = fun(True) + mx.eval(out) + expected = fun(False) + self.assertTrue(mx.allclose(out, expected)) + if __name__ == "__main__": mlx_tests.MLXTestRunner()