fix compile copying (#2871)

This commit is contained in:
Awni Hannun
2025-12-04 12:32:56 -08:00
committed by GitHub
parent 941cfe23d7
commit a6d6717181
2 changed files with 39 additions and 1 deletions

View File

@@ -467,6 +467,10 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
for (auto& o : outputs) {
old_to_new.insert({o.id(), o});
io_set.insert(o.id());
for (auto& s : o.siblings()) {
old_to_new.insert({s.id(), s});
io_set.insert(s.id());
}
}
for (auto& i : inputs) {
io_set.insert(i.id());
@@ -514,6 +518,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
for (auto& i : o.inputs()) {
i = old_to_new.find(i.id())->second;
}
for (auto& s : o.siblings()) {
io_set.insert(s.id());
for (auto& i : s.inputs()) {
i = old_to_new.find(i.id())->second;
}
}
}
tape = std::move(new_tape);
@@ -526,7 +536,6 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec);
}
parents_map = std::move(new_parents_map);
return {tape, parents_map};
}

View File

@@ -1223,6 +1223,35 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
def test_compile_output_with_siblings(self):
@mx.compile
def fun(x, y):
return mx.divmod(mx.abs(x), mx.abs(y))[0]
out = fun(mx.array(1.0), mx.array(1.0))
self.assertEqual(out.item(), 1.0)
# Make sure the following compiles without issue
def loss_fn(params, x):
emb, w = params
return mx.fast.layer_norm(emb[x], w, None, 1e-4).sum()
emb = mx.zeros((10, 32))
w = mx.zeros((32,))
loss_and_grad_fn = mx.value_and_grad(loss_fn)
x = mx.zeros(shape=(4, 32), dtype=mx.int32)
mx.eval(x, emb, w)
@mx.compile
def step(emb, w, x):
loss, grads = loss_and_grad_fn((emb, w), x)
return loss, grads
loss, grads = step(emb, w, x)
mx.eval(loss, grads)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()