From a6d671718143a0d96715fa0d34962e2530599bdb Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 4 Dec 2025 12:32:56 -0800 Subject: [PATCH] fix compile copying (#2871) --- mlx/compile.cpp | 11 ++++++++++- python/tests/test_compile.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 4649a3708..ca5f06993 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -467,6 +467,10 @@ std::pair, ParentsMap> compile_dfs( for (auto& o : outputs) { old_to_new.insert({o.id(), o}); io_set.insert(o.id()); + for (auto& s : o.siblings()) { + old_to_new.insert({s.id(), s}); + io_set.insert(s.id()); + } } for (auto& i : inputs) { io_set.insert(i.id()); @@ -514,6 +518,12 @@ std::pair, ParentsMap> compile_dfs( for (auto& i : o.inputs()) { i = old_to_new.find(i.id())->second; } + for (auto& s : o.siblings()) { + io_set.insert(s.id()); + for (auto& i : s.inputs()) { + i = old_to_new.find(i.id())->second; + } + } } tape = std::move(new_tape); @@ -526,7 +536,6 @@ std::pair, ParentsMap> compile_dfs( new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec); } parents_map = std::move(new_parents_map); - return {tape, parents_map}; } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index bc3bf80f3..b9bff614f 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1223,6 +1223,35 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13]))) self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60]))) + def test_compile_output_with_siblings(self): + @mx.compile + def fun(x, y): + return mx.divmod(mx.abs(x), mx.abs(y))[0] + + out = fun(mx.array(1.0), mx.array(1.0)) + self.assertEqual(out.item(), 1.0) + + # Make sure the following compiles without issue + def loss_fn(params, x): + emb, w = params + return mx.fast.layer_norm(emb[x], w, None, 1e-4).sum() + + emb = mx.zeros((10, 32)) + w = mx.zeros((32,)) + + loss_and_grad_fn = mx.value_and_grad(loss_fn) + + x = mx.zeros(shape=(4, 32), dtype=mx.int32) + mx.eval(x, emb, w) + + @mx.compile + def step(emb, w, x): + loss, grads = loss_and_grad_fn((emb, w), x) + return loss, grads + + loss, grads = step(emb, w, x) + mx.eval(loss, grads) + if __name__ == "__main__": mlx_tests.MLXTestRunner()