From 1a87dc5ea817af0e2a9c80e0bef46a66c984b6e9 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 2 Apr 2024 08:42:31 -0700 Subject: [PATCH] Fix compile fusion for multi-output edge cases (#950) * Fix compile fusion for multi-output edge cases * Add a test for multi-output compile --- mlx/compile.cpp | 54 ++++++++++++++++++++++++------------ python/tests/test_compile.py | 13 +++++++++ 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 28419351c..3a00a01b7 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -165,25 +165,32 @@ CompileMode& compile_mode() { using ParentsMap = std::unordered_map>>; +// Helper like below but only merges the two provided arrays. If the src has +// siblings then these won't be merged to the dst. +void merge_one(array& dst, array& src, ParentsMap& parents_map) { + auto src_parents = parents_map.find(src.id()); + if (src_parents == parents_map.end()) { + return; + } + auto& pairs = parents_map[dst.id()]; + for (auto& parent : src_parents->second) { + parent.first.inputs()[parent.second] = dst; + pairs.push_back(parent); + } + // Remove the source from the map to avoid fusing with it again + parents_map.erase(src_parents); +}; + // Helper that merges two arrays in the graph by setting the parents of the -// source to point to the destination +// source to point to the destination. The arrays are assumed to be coming from +// equivalent primitives so their siblings are merged as well. void merge(array& dst, array& src, ParentsMap& parents_map) { // Canonicalize the order of the primitives outputs auto sources = src.outputs(); auto dests = dst.outputs(); // For each src parent, point it to the corresponding dst for (int i = 0; i < sources.size(); ++i) { - auto src_parents = parents_map.find(sources[i].id()); - if (src_parents == parents_map.end()) { - continue; - } - auto& pairs = parents_map[dests[i].id()]; - for (auto& parent : src_parents->second) { - parent.first.inputs()[parent.second] = dests[i]; - pairs.push_back(parent); - } - // Remove the source from the map to avoid fusing with it again - parents_map.erase(src_parents); + merge_one(dests[i], sources[i], parents_map); } }; @@ -524,9 +531,14 @@ void compile_fuse( // - Collect inputs to the new compiled primitive // - Add fusable primitives to a tape in the correct order - std::function recurse; + std::function&)> + recurse; std::unordered_set cache; - recurse = [&](const array& a, int depth, const Stream& s) { + recurse = [&](const array& a, + int depth, + const Stream& s, + const std::vector& shape) { if (cache.find(a.id()) != cache.end()) { return; } @@ -536,8 +548,10 @@ void compile_fuse( // - Constant input // - Stream mismatch // - Non fusable primitive + // - Is global output but has a different shape if (depth >= max_compile_depth || !a.has_primitive() || - a.primitive().stream() != s || !is_fusable(a.primitive())) { + a.primitive().stream() != s || !is_fusable(a.primitive()) || + (output_map.find(a.id()) != output_map.end() && a.shape() != shape)) { return; } @@ -564,13 +578,13 @@ void compile_fuse( cache.insert({a.id()}); for (auto& in : a.inputs()) { - recurse(in, depth + 1, s); + recurse(in, depth + 1, s, shape); } }; if (arr.has_primitive()) { Stream s = arr.primitive().stream(); - recurse(arr, 0, s); + recurse(arr, 0, s, arr.shape()); } // Not worth fusing a single primitive @@ -634,6 +648,10 @@ void compile_fuse( std::vector> shapes; std::vector types; for (auto& o : old_outputs) { + if (o.shape() != old_outputs.back().shape()) { + throw std::runtime_error( + "[compile] Compilation failed. Tried to fuse operations with different output shapes"); + } shapes.push_back(o.shape()); types.push_back(o.dtype()); } @@ -676,7 +694,7 @@ void compile_fuse( // - Update outputs parents to point to compiled outputs // - Update any overall graph outputs to be compiled outputs for (int o = 0; o < old_outputs.size(); ++o) { - merge(compiled_outputs[o], old_outputs[o], parents_map); + merge_one(compiled_outputs[o], old_outputs[o], parents_map); if (auto it = output_map.find(old_outputs[o].id()); it != output_map.end()) { it->second = compiled_outputs[o]; diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index d7cce31f5..20ca62101 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -691,6 +691,19 @@ class TestCompile(mlx_tests.MLXTestCase): out = mx.compile(fn)(mx.array(10.0), mx.array(20.0)) self.assertEqual(out.item(), 10.0) + def test_compile_multi_output(self): + def fn(x): + ys = [x] + for i in range(5): + ys.append(ys[-1] + x) + return ys, mx.sum(ys[-1]) + + x = mx.ones(1, dtype=mx.int32) + y1 = mx.compile(fn)(x)[1] + y2 = fn(x)[1] + self.assertEqual(y1.item(), y2.item()) + self.assertEqual(y1.item(), 6) + if __name__ == "__main__": unittest.main()