mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-11 15:06:42 +08:00
fix compile copying (#2871)
This commit is contained in:
@@ -467,6 +467,10 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
for (auto& o : outputs) {
|
for (auto& o : outputs) {
|
||||||
old_to_new.insert({o.id(), o});
|
old_to_new.insert({o.id(), o});
|
||||||
io_set.insert(o.id());
|
io_set.insert(o.id());
|
||||||
|
for (auto& s : o.siblings()) {
|
||||||
|
old_to_new.insert({s.id(), s});
|
||||||
|
io_set.insert(s.id());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (auto& i : inputs) {
|
for (auto& i : inputs) {
|
||||||
io_set.insert(i.id());
|
io_set.insert(i.id());
|
||||||
@@ -514,6 +518,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
for (auto& i : o.inputs()) {
|
for (auto& i : o.inputs()) {
|
||||||
i = old_to_new.find(i.id())->second;
|
i = old_to_new.find(i.id())->second;
|
||||||
}
|
}
|
||||||
|
for (auto& s : o.siblings()) {
|
||||||
|
io_set.insert(s.id());
|
||||||
|
for (auto& i : s.inputs()) {
|
||||||
|
i = old_to_new.find(i.id())->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
tape = std::move(new_tape);
|
tape = std::move(new_tape);
|
||||||
|
|
||||||
@@ -526,7 +536,6 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec);
|
new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec);
|
||||||
}
|
}
|
||||||
parents_map = std::move(new_parents_map);
|
parents_map = std::move(new_parents_map);
|
||||||
|
|
||||||
return {tape, parents_map};
|
return {tape, parents_map};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1223,6 +1223,35 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
|
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
|
||||||
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
|
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
|
||||||
|
|
||||||
|
def test_compile_output_with_siblings(self):
|
||||||
|
@mx.compile
|
||||||
|
def fun(x, y):
|
||||||
|
return mx.divmod(mx.abs(x), mx.abs(y))[0]
|
||||||
|
|
||||||
|
out = fun(mx.array(1.0), mx.array(1.0))
|
||||||
|
self.assertEqual(out.item(), 1.0)
|
||||||
|
|
||||||
|
# Make sure the following compiles without issue
|
||||||
|
def loss_fn(params, x):
|
||||||
|
emb, w = params
|
||||||
|
return mx.fast.layer_norm(emb[x], w, None, 1e-4).sum()
|
||||||
|
|
||||||
|
emb = mx.zeros((10, 32))
|
||||||
|
w = mx.zeros((32,))
|
||||||
|
|
||||||
|
loss_and_grad_fn = mx.value_and_grad(loss_fn)
|
||||||
|
|
||||||
|
x = mx.zeros(shape=(4, 32), dtype=mx.int32)
|
||||||
|
mx.eval(x, emb, w)
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def step(emb, w, x):
|
||||||
|
loss, grads = loss_and_grad_fn((emb, w), x)
|
||||||
|
return loss, grads
|
||||||
|
|
||||||
|
loss, grads = step(emb, w, x)
|
||||||
|
mx.eval(loss, grads)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
Reference in New Issue
Block a user