mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix compile multi capture (#2678)
* fix compile when compiling multiple lambdas with the same capture * add test
This commit is contained in:
@@ -294,6 +294,11 @@ class array {
|
||||
return array_desc_->siblings;
|
||||
}
|
||||
|
||||
/** The array's position in the sibling list. */
|
||||
int sibling_position() const {
|
||||
return array_desc_->position;
|
||||
}
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
array_desc_->position = position;
|
||||
|
||||
146
mlx/compile.cpp
146
mlx/compile.cpp
@@ -412,51 +412,121 @@ compile_trace(
|
||||
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||
std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& original_inputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::vector<array> tape;
|
||||
std::unordered_set<std::uintptr_t> input_set;
|
||||
std::unordered_set<std::uintptr_t> original_input_set;
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||
parents_map;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
input_set.insert(inputs[i].id());
|
||||
original_input_set.insert(original_inputs[i].id());
|
||||
{
|
||||
std::function<void(const array&)> recurse;
|
||||
std::unordered_set<std::uintptr_t> input_set;
|
||||
std::unordered_set<std::uintptr_t> original_input_set;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
input_set.insert(inputs[i].id());
|
||||
original_input_set.insert(original_inputs[i].id());
|
||||
}
|
||||
|
||||
// DFS the graph to build the tape, and log parents and scalars
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
recurse = [&](const array& a) {
|
||||
auto id = a.id();
|
||||
if (original_input_set.find(id) != original_input_set.end()) {
|
||||
throw std::invalid_argument(
|
||||
"[compile] Attempting to compile a function with uncaptured inputs is not allowed.");
|
||||
}
|
||||
if (cache.find(id) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < a.inputs().size(); i++) {
|
||||
auto& in = a.inputs()[i];
|
||||
parents_map[in.id()].push_back({a, i});
|
||||
for (auto& s : a.siblings()) {
|
||||
parents_map[in.id()].push_back({s, i});
|
||||
}
|
||||
// Don't recurse on inputs (but add them to the tape for the purpose
|
||||
// of future optimizations)
|
||||
if (input_set.find(a.id()) == input_set.end()) {
|
||||
recurse(in);
|
||||
}
|
||||
}
|
||||
cache.insert(id);
|
||||
for (auto& s : a.siblings()) {
|
||||
cache.insert(s.id());
|
||||
}
|
||||
tape.push_back(a);
|
||||
};
|
||||
for (auto& a : outputs) {
|
||||
recurse(a);
|
||||
}
|
||||
}
|
||||
|
||||
// DFS the graph to build the tape, and log parents and scalars
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
recurse = [&](const array& a) {
|
||||
auto id = a.id();
|
||||
if (original_input_set.find(id) != original_input_set.end()) {
|
||||
throw std::invalid_argument(
|
||||
"[compile] Attempting to compile a function with uncaptured inputs is not allowed.");
|
||||
}
|
||||
if (cache.find(id) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < a.inputs().size(); i++) {
|
||||
auto& in = a.inputs()[i];
|
||||
parents_map[in.id()].push_back({a, i});
|
||||
for (auto& s : a.siblings()) {
|
||||
parents_map[in.id()].push_back({s, i});
|
||||
}
|
||||
// Don't recurse on inputs (but add them to the tape for the purpose
|
||||
// of future optimizations)
|
||||
if (input_set.find(a.id()) == input_set.end()) {
|
||||
recurse(in);
|
||||
}
|
||||
}
|
||||
cache.insert(id);
|
||||
for (auto& s : a.siblings()) {
|
||||
cache.insert(s.id());
|
||||
}
|
||||
tape.push_back(a);
|
||||
};
|
||||
for (auto& a : outputs) {
|
||||
recurse(a);
|
||||
// Deep copy the tape and parents map while preserving inputs and outputs
|
||||
std::vector<array> new_tape;
|
||||
std::unordered_set<uintptr_t> io_set;
|
||||
std::unordered_map<uintptr_t, array> old_to_new;
|
||||
for (auto& o : outputs) {
|
||||
old_to_new.insert({o.id(), o});
|
||||
io_set.insert(o.id());
|
||||
}
|
||||
for (auto& i : inputs) {
|
||||
io_set.insert(i.id());
|
||||
old_to_new.insert({i.id(), i});
|
||||
}
|
||||
|
||||
new_tape.reserve(tape.size());
|
||||
for (auto& arr : tape) {
|
||||
if (!arr.has_primitive() || (io_set.find(arr.id()) != io_set.end())) {
|
||||
old_to_new.insert({arr.id(), arr});
|
||||
new_tape.push_back(arr);
|
||||
continue;
|
||||
}
|
||||
std::vector<array> inputs;
|
||||
inputs.reserve(arr.inputs().size());
|
||||
for (auto& i : arr.inputs()) {
|
||||
inputs.push_back(old_to_new.find(i.id())->second);
|
||||
}
|
||||
if (arr.siblings().size() > 0) {
|
||||
std::vector<Dtype> types;
|
||||
std::vector<Shape> shapes;
|
||||
auto out = arr.outputs();
|
||||
for (auto& o : out) {
|
||||
types.push_back(o.dtype());
|
||||
shapes.push_back(o.shape());
|
||||
}
|
||||
auto as = array::make_arrays(
|
||||
std::move(shapes), types, arr.primitive_ptr(), std::move(inputs));
|
||||
for (int i = 0; i < out.size(); ++i) {
|
||||
old_to_new.insert({out[i].id(), as[i]});
|
||||
}
|
||||
new_tape.push_back(as[arr.sibling_position()]);
|
||||
} else {
|
||||
auto a = array(
|
||||
arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs));
|
||||
old_to_new.insert({arr.id(), a});
|
||||
new_tape.push_back(a);
|
||||
}
|
||||
}
|
||||
io_set.clear();
|
||||
for (auto& o : outputs) {
|
||||
if (!(io_set.insert(o.id()).second)) {
|
||||
continue;
|
||||
}
|
||||
for (auto& i : o.inputs()) {
|
||||
i = old_to_new.find(i.id())->second;
|
||||
}
|
||||
}
|
||||
tape = std::move(new_tape);
|
||||
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||
new_parents_map;
|
||||
for (auto& [id, vec] : parents_map) {
|
||||
for (auto& [a, _] : vec) {
|
||||
a = old_to_new.find(a.id())->second;
|
||||
}
|
||||
new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec);
|
||||
}
|
||||
parents_map = std::move(new_parents_map);
|
||||
|
||||
return {tape, parents_map};
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ using ParentsMap =
|
||||
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||
std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& original_inputs);
|
||||
|
||||
// Simplify the tape.
|
||||
|
||||
Reference in New Issue
Block a user