mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix compile fusion for multi-output edge cases (#950)
* Fix compile fusion for multi-output edge cases * Add a test for multi-output compile
This commit is contained in:
parent
2427fa171e
commit
1a87dc5ea8
@ -165,25 +165,32 @@ CompileMode& compile_mode() {
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
// Helper like below but only merges the two provided arrays. If the src has
|
||||
// siblings then these won't be merged to the dst.
|
||||
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
||||
auto src_parents = parents_map.find(src.id());
|
||||
if (src_parents == parents_map.end()) {
|
||||
return;
|
||||
}
|
||||
auto& pairs = parents_map[dst.id()];
|
||||
for (auto& parent : src_parents->second) {
|
||||
parent.first.inputs()[parent.second] = dst;
|
||||
pairs.push_back(parent);
|
||||
}
|
||||
// Remove the source from the map to avoid fusing with it again
|
||||
parents_map.erase(src_parents);
|
||||
};
|
||||
|
||||
// Helper that merges two arrays in the graph by setting the parents of the
|
||||
// source to point to the destination
|
||||
// source to point to the destination. The arrays are assumed to be coming from
|
||||
// equivalent primitives so their siblings are merged as well.
|
||||
void merge(array& dst, array& src, ParentsMap& parents_map) {
|
||||
// Canonicalize the order of the primitives outputs
|
||||
auto sources = src.outputs();
|
||||
auto dests = dst.outputs();
|
||||
// For each src parent, point it to the corresponding dst
|
||||
for (int i = 0; i < sources.size(); ++i) {
|
||||
auto src_parents = parents_map.find(sources[i].id());
|
||||
if (src_parents == parents_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& pairs = parents_map[dests[i].id()];
|
||||
for (auto& parent : src_parents->second) {
|
||||
parent.first.inputs()[parent.second] = dests[i];
|
||||
pairs.push_back(parent);
|
||||
}
|
||||
// Remove the source from the map to avoid fusing with it again
|
||||
parents_map.erase(src_parents);
|
||||
merge_one(dests[i], sources[i], parents_map);
|
||||
}
|
||||
};
|
||||
|
||||
@ -524,9 +531,14 @@ void compile_fuse(
|
||||
// - Collect inputs to the new compiled primitive
|
||||
// - Add fusable primitives to a tape in the correct order
|
||||
|
||||
std::function<void(const array&, int, const Stream&)> recurse;
|
||||
std::function<void(
|
||||
const array&, int, const Stream&, const std::vector<int>&)>
|
||||
recurse;
|
||||
std::unordered_set<uintptr_t> cache;
|
||||
recurse = [&](const array& a, int depth, const Stream& s) {
|
||||
recurse = [&](const array& a,
|
||||
int depth,
|
||||
const Stream& s,
|
||||
const std::vector<int>& shape) {
|
||||
if (cache.find(a.id()) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
@ -536,8 +548,10 @@ void compile_fuse(
|
||||
// - Constant input
|
||||
// - Stream mismatch
|
||||
// - Non fusable primitive
|
||||
// - Is global output but has a different shape
|
||||
if (depth >= max_compile_depth || !a.has_primitive() ||
|
||||
a.primitive().stream() != s || !is_fusable(a.primitive())) {
|
||||
a.primitive().stream() != s || !is_fusable(a.primitive()) ||
|
||||
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -564,13 +578,13 @@ void compile_fuse(
|
||||
cache.insert({a.id()});
|
||||
|
||||
for (auto& in : a.inputs()) {
|
||||
recurse(in, depth + 1, s);
|
||||
recurse(in, depth + 1, s, shape);
|
||||
}
|
||||
};
|
||||
|
||||
if (arr.has_primitive()) {
|
||||
Stream s = arr.primitive().stream();
|
||||
recurse(arr, 0, s);
|
||||
recurse(arr, 0, s, arr.shape());
|
||||
}
|
||||
|
||||
// Not worth fusing a single primitive
|
||||
@ -634,6 +648,10 @@ void compile_fuse(
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> types;
|
||||
for (auto& o : old_outputs) {
|
||||
if (o.shape() != old_outputs.back().shape()) {
|
||||
throw std::runtime_error(
|
||||
"[compile] Compilation failed. Tried to fuse operations with different output shapes");
|
||||
}
|
||||
shapes.push_back(o.shape());
|
||||
types.push_back(o.dtype());
|
||||
}
|
||||
@ -676,7 +694,7 @@ void compile_fuse(
|
||||
// - Update outputs parents to point to compiled outputs
|
||||
// - Update any overall graph outputs to be compiled outputs
|
||||
for (int o = 0; o < old_outputs.size(); ++o) {
|
||||
merge(compiled_outputs[o], old_outputs[o], parents_map);
|
||||
merge_one(compiled_outputs[o], old_outputs[o], parents_map);
|
||||
if (auto it = output_map.find(old_outputs[o].id());
|
||||
it != output_map.end()) {
|
||||
it->second = compiled_outputs[o];
|
||||
|
@ -691,6 +691,19 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = mx.compile(fn)(mx.array(10.0), mx.array(20.0))
|
||||
self.assertEqual(out.item(), 10.0)
|
||||
|
||||
def test_compile_multi_output(self):
|
||||
def fn(x):
|
||||
ys = [x]
|
||||
for i in range(5):
|
||||
ys.append(ys[-1] + x)
|
||||
return ys, mx.sum(ys[-1])
|
||||
|
||||
x = mx.ones(1, dtype=mx.int32)
|
||||
y1 = mx.compile(fn)(x)[1]
|
||||
y2 = fn(x)[1]
|
||||
self.assertEqual(y1.item(), y2.item())
|
||||
self.assertEqual(y1.item(), 6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user