fix compile when compiling multiple lambdas with the same capture

This commit is contained in:
Awni Hannun
2025-10-15 17:08:22 -07:00
parent e9eab527eb
commit c473719b23
3 changed files with 68 additions and 4 deletions

View File

@@ -412,7 +412,7 @@ compile_trace(
// Traverses the graph to build a tape and a map of array ids to their parents
std::pair<std::vector<array>, ParentsMap> compile_dfs(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
std::vector<array>& outputs,
const std::vector<array>& original_inputs) {
std::function<void(const array&)> recurse;
std::vector<array> tape;
@@ -457,6 +457,71 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
for (auto& a : outputs) {
recurse(a);
}
// Deep copy the tape and parents map
std::vector<array> new_tape;
std::unordered_set<uintptr_t> io_set;
std::unordered_map<uintptr_t, array> old_to_new;
for (auto& o : outputs) {
io_set.insert(o.id());
}
for (auto& i : inputs) {
io_set.insert(i.id());
old_to_new.insert({i.id(), i});
}
new_tape.reserve(tape.size());
for (auto& arr : tape) {
if (!arr.has_primitive() || (io_set.find(arr.id()) != io_set.end())) {
old_to_new.insert({arr.id(), arr});
new_tape.push_back(arr);
continue;
}
std::vector<array> inputs;
inputs.reserve(arr.inputs().size());
for (auto& i : arr.inputs()) {
inputs.push_back(old_to_new.find(i.id())->second);
}
if (arr.siblings().size() > 0) {
// use make_arrays
std::vector<Dtype> types;
std::vector<Shape> shapes;
auto out = arr.outputs();
for (auto& o : out) {
types.push_back(o.dtype());
shapes.push_back(o.shape());
}
auto as = array::make_arrays(
std::move(shapes), types, arr.primitive_ptr(), std::move(inputs));
for (int i = 0; i < out.size(); ++i) {
old_to_new.insert({out[i].id(), as[i]});
}
// TODO maybe need to preserve position of sibling that is in tape
new_tape.push_back(as[0]);
} else {
auto a = array(
arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs));
old_to_new.insert({arr.id(), a});
new_tape.push_back(a);
}
}
for (auto& o : outputs) {
for (auto& i : o.inputs()) {
i = old_to_new.find(i.id())->second;
}
}
tape = std::move(new_tape);
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
new_parents_map;
for (auto& [id, vec] : parents_map) {
for (auto& [a, _] : vec) {
a = old_to_new.find(a.id())->second;
}
new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec);
}
parents_map = std::move(new_parents_map);
return {tape, parents_map};
}

View File

@@ -47,7 +47,7 @@ using ParentsMap =
// Traverses the graph to build a tape and a map of array ids to their parents
std::pair<std::vector<array>, ParentsMap> compile_dfs(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
std::vector<array>& outputs,
const std::vector<array>& original_inputs);
// Simplify the tape.

View File

@@ -194,8 +194,7 @@ auto multi_one(const std::vector<array>&) {
auto multi_two(const std::vector<array>&) {
auto a = array(1.0);
auto b = array(1.0);
auto c = divmod(a, b);
return std::vector<array>{c};
return divmod(a, b);
}
auto multi_three(const std::vector<array>&) {