Fix compile fusion for multi-output edge cases (#950)

* Fix compile fusion for multi-output edge cases

* Add a test for multi-output compile
This commit is contained in:
Angelos Katharopoulos 2024-04-02 08:42:31 -07:00 committed by GitHub
parent 2427fa171e
commit 1a87dc5ea8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 18 deletions

View File

@ -165,25 +165,32 @@ CompileMode& compile_mode() {
using ParentsMap = using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>; std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Helper like `merge` below but only merges the two provided arrays. If the
// src has siblings then these won't be merged to the dst.
//
// Every (parent, input index) pair recorded for `src` in `parents_map` gets
// that input slot reassigned to `dst`, and the pair is appended to the entry
// for `dst`. The entry for `src` is then erased. This is a no-op when `src`
// has no entry in `parents_map`.
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
  auto src_it = parents_map.find(src.id());
  if (src_it == parents_map.end()) {
    return;
  }
  // Bind a reference to the source's parent list *before* touching the map
  // again: operator[] below may insert a new entry and trigger a rehash,
  // which invalidates iterators into an unordered_map but never references
  // to its elements.
  auto& src_parents = src_it->second;
  auto& pairs = parents_map[dst.id()];
  for (auto& parent : src_parents) {
    parent.first.inputs()[parent.second] = dst;
    pairs.push_back(parent);
  }
  // Remove the source from the map to avoid fusing with it again. Erase by
  // key because the iterator obtained above may no longer be valid.
  parents_map.erase(src.id());
}
// Helper that merges two arrays in the graph by setting the parents of the // Helper that merges two arrays in the graph by setting the parents of the
// source to point to the destination // source to point to the destination. The arrays are assumed to be coming from
// equivalent primitives so their siblings are merged as well.
void merge(array& dst, array& src, ParentsMap& parents_map) { void merge(array& dst, array& src, ParentsMap& parents_map) {
// Canonicalize the order of the primitives outputs // Canonicalize the order of the primitives outputs
auto sources = src.outputs(); auto sources = src.outputs();
auto dests = dst.outputs(); auto dests = dst.outputs();
// For each src parent, point it to the corresponding dst // For each src parent, point it to the corresponding dst
for (int i = 0; i < sources.size(); ++i) { for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id()); merge_one(dests[i], sources[i], parents_map);
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
} }
}; };
@ -524,9 +531,14 @@ void compile_fuse(
// - Collect inputs to the new compiled primitive // - Collect inputs to the new compiled primitive
// - Add fusable primitives to a tape in the correct order // - Add fusable primitives to a tape in the correct order
std::function<void(const array&, int, const Stream&)> recurse; std::function<void(
const array&, int, const Stream&, const std::vector<int>&)>
recurse;
std::unordered_set<uintptr_t> cache; std::unordered_set<uintptr_t> cache;
recurse = [&](const array& a, int depth, const Stream& s) { recurse = [&](const array& a,
int depth,
const Stream& s,
const std::vector<int>& shape) {
if (cache.find(a.id()) != cache.end()) { if (cache.find(a.id()) != cache.end()) {
return; return;
} }
@ -536,8 +548,10 @@ void compile_fuse(
// - Constant input // - Constant input
// - Stream mismatch // - Stream mismatch
// - Non fusable primitive // - Non fusable primitive
// - Is global output but has a different shape
if (depth >= max_compile_depth || !a.has_primitive() || if (depth >= max_compile_depth || !a.has_primitive() ||
a.primitive().stream() != s || !is_fusable(a.primitive())) { a.primitive().stream() != s || !is_fusable(a.primitive()) ||
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
return; return;
} }
@ -564,13 +578,13 @@ void compile_fuse(
cache.insert({a.id()}); cache.insert({a.id()});
for (auto& in : a.inputs()) { for (auto& in : a.inputs()) {
recurse(in, depth + 1, s); recurse(in, depth + 1, s, shape);
} }
}; };
if (arr.has_primitive()) { if (arr.has_primitive()) {
Stream s = arr.primitive().stream(); Stream s = arr.primitive().stream();
recurse(arr, 0, s); recurse(arr, 0, s, arr.shape());
} }
// Not worth fusing a single primitive // Not worth fusing a single primitive
@ -634,6 +648,10 @@ void compile_fuse(
std::vector<std::vector<int>> shapes; std::vector<std::vector<int>> shapes;
std::vector<Dtype> types; std::vector<Dtype> types;
for (auto& o : old_outputs) { for (auto& o : old_outputs) {
if (o.shape() != old_outputs.back().shape()) {
throw std::runtime_error(
"[compile] Compilation failed. Tried to fuse operations with different output shapes");
}
shapes.push_back(o.shape()); shapes.push_back(o.shape());
types.push_back(o.dtype()); types.push_back(o.dtype());
} }
@ -676,7 +694,7 @@ void compile_fuse(
// - Update outputs parents to point to compiled outputs // - Update outputs parents to point to compiled outputs
// - Update any overall graph outputs to be compiled outputs // - Update any overall graph outputs to be compiled outputs
for (int o = 0; o < old_outputs.size(); ++o) { for (int o = 0; o < old_outputs.size(); ++o) {
merge(compiled_outputs[o], old_outputs[o], parents_map); merge_one(compiled_outputs[o], old_outputs[o], parents_map);
if (auto it = output_map.find(old_outputs[o].id()); if (auto it = output_map.find(old_outputs[o].id());
it != output_map.end()) { it != output_map.end()) {
it->second = compiled_outputs[o]; it->second = compiled_outputs[o];

View File

@ -691,6 +691,19 @@ class TestCompile(mlx_tests.MLXTestCase):
out = mx.compile(fn)(mx.array(10.0), mx.array(20.0)) out = mx.compile(fn)(mx.array(10.0), mx.array(20.0))
self.assertEqual(out.item(), 10.0) self.assertEqual(out.item(), 10.0)
def test_compile_multi_output(self):
    """Compiled and eager runs must agree on the reduced output.

    The function returns every intermediate of a chain of additions along
    with the sum of the last one; only the sum is compared here.
    """

    def fn(x):
        # Chain of additions: x, 2x, ..., 6x. All of them are returned.
        ys = [x]
        for _ in range(5):
            ys.append(ys[-1] + x)
        return ys, mx.sum(ys[-1])

    x = mx.ones(1, dtype=mx.int32)
    compiled_total = mx.compile(fn)(x)[1]
    eager_total = fn(x)[1]
    self.assertEqual(compiled_total.item(), eager_total.item())
    self.assertEqual(compiled_total.item(), 6)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()