mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-09 13:25:32 +08:00
fix compile copying (#2871)
This commit is contained in:
@@ -467,6 +467,10 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
for (auto& o : outputs) {
|
||||
old_to_new.insert({o.id(), o});
|
||||
io_set.insert(o.id());
|
||||
for (auto& s : o.siblings()) {
|
||||
old_to_new.insert({s.id(), s});
|
||||
io_set.insert(s.id());
|
||||
}
|
||||
}
|
||||
for (auto& i : inputs) {
|
||||
io_set.insert(i.id());
|
||||
@@ -514,6 +518,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
for (auto& i : o.inputs()) {
|
||||
i = old_to_new.find(i.id())->second;
|
||||
}
|
||||
for (auto& s : o.siblings()) {
|
||||
io_set.insert(s.id());
|
||||
for (auto& i : s.inputs()) {
|
||||
i = old_to_new.find(i.id())->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
tape = std::move(new_tape);
|
||||
|
||||
@@ -526,7 +536,6 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec);
|
||||
}
|
||||
parents_map = std::move(new_parents_map);
|
||||
|
||||
return {tape, parents_map};
|
||||
}
|
||||
|
||||
|
||||
@@ -1223,6 +1223,35 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
|
||||
|
||||
def test_compile_output_with_siblings(self):
|
||||
@mx.compile
|
||||
def fun(x, y):
|
||||
return mx.divmod(mx.abs(x), mx.abs(y))[0]
|
||||
|
||||
out = fun(mx.array(1.0), mx.array(1.0))
|
||||
self.assertEqual(out.item(), 1.0)
|
||||
|
||||
# Make sure the following compiles without issue
|
||||
def loss_fn(params, x):
|
||||
emb, w = params
|
||||
return mx.fast.layer_norm(emb[x], w, None, 1e-4).sum()
|
||||
|
||||
emb = mx.zeros((10, 32))
|
||||
w = mx.zeros((32,))
|
||||
|
||||
loss_and_grad_fn = mx.value_and_grad(loss_fn)
|
||||
|
||||
x = mx.zeros(shape=(4, 32), dtype=mx.int32)
|
||||
mx.eval(x, emb, w)
|
||||
|
||||
@mx.compile
|
||||
def step(emb, w, x):
|
||||
loss, grads = loss_and_grad_fn((emb, w), x)
|
||||
return loss, grads
|
||||
|
||||
loss, grads = step(emb, w, x)
|
||||
mx.eval(loss, grads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user