mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix compile fusion for multi-output edge cases (#950)
* Fix compile fusion for multi-output edge cases
* Add a test for multi-output compile
This commit is contained in:
parent
2427fa171e
commit
1a87dc5ea8
@ -165,25 +165,32 @@ CompileMode& compile_mode() {
|
|||||||
using ParentsMap =
|
using ParentsMap =
|
||||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||||
|
|
||||||
|
// Helper like below but only merges the two provided arrays. If the src has
|
||||||
|
// siblings then these won't be merged to the dst.
|
||||||
|
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
||||||
|
auto src_parents = parents_map.find(src.id());
|
||||||
|
if (src_parents == parents_map.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto& pairs = parents_map[dst.id()];
|
||||||
|
for (auto& parent : src_parents->second) {
|
||||||
|
parent.first.inputs()[parent.second] = dst;
|
||||||
|
pairs.push_back(parent);
|
||||||
|
}
|
||||||
|
// Remove the source from the map to avoid fusing with it again
|
||||||
|
parents_map.erase(src_parents);
|
||||||
|
};
|
||||||
|
|
||||||
// Helper that merges two arrays in the graph by setting the parents of the
|
// Helper that merges two arrays in the graph by setting the parents of the
|
||||||
// source to point to the destination
|
// source to point to the destination. The arrays are assumed to be coming from
|
||||||
|
// equivalent primitives so their siblings are merged as well.
|
||||||
void merge(array& dst, array& src, ParentsMap& parents_map) {
|
void merge(array& dst, array& src, ParentsMap& parents_map) {
|
||||||
// Canonicalize the order of the primitives outputs
|
// Canonicalize the order of the primitives outputs
|
||||||
auto sources = src.outputs();
|
auto sources = src.outputs();
|
||||||
auto dests = dst.outputs();
|
auto dests = dst.outputs();
|
||||||
// For each src parent, point it to the corresponding dst
|
// For each src parent, point it to the corresponding dst
|
||||||
for (int i = 0; i < sources.size(); ++i) {
|
for (int i = 0; i < sources.size(); ++i) {
|
||||||
auto src_parents = parents_map.find(sources[i].id());
|
merge_one(dests[i], sources[i], parents_map);
|
||||||
if (src_parents == parents_map.end()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto& pairs = parents_map[dests[i].id()];
|
|
||||||
for (auto& parent : src_parents->second) {
|
|
||||||
parent.first.inputs()[parent.second] = dests[i];
|
|
||||||
pairs.push_back(parent);
|
|
||||||
}
|
|
||||||
// Remove the source from the map to avoid fusing with it again
|
|
||||||
parents_map.erase(src_parents);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -524,9 +531,14 @@ void compile_fuse(
|
|||||||
// - Collect inputs to the new compiled primitive
|
// - Collect inputs to the new compiled primitive
|
||||||
// - Add fusable primitives to a tape in the correct order
|
// - Add fusable primitives to a tape in the correct order
|
||||||
|
|
||||||
std::function<void(const array&, int, const Stream&)> recurse;
|
std::function<void(
|
||||||
|
const array&, int, const Stream&, const std::vector<int>&)>
|
||||||
|
recurse;
|
||||||
std::unordered_set<uintptr_t> cache;
|
std::unordered_set<uintptr_t> cache;
|
||||||
recurse = [&](const array& a, int depth, const Stream& s) {
|
recurse = [&](const array& a,
|
||||||
|
int depth,
|
||||||
|
const Stream& s,
|
||||||
|
const std::vector<int>& shape) {
|
||||||
if (cache.find(a.id()) != cache.end()) {
|
if (cache.find(a.id()) != cache.end()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -536,8 +548,10 @@ void compile_fuse(
|
|||||||
// - Constant input
|
// - Constant input
|
||||||
// - Stream mismatch
|
// - Stream mismatch
|
||||||
// - Non fusable primitive
|
// - Non fusable primitive
|
||||||
|
// - Is global output but has a different shape
|
||||||
if (depth >= max_compile_depth || !a.has_primitive() ||
|
if (depth >= max_compile_depth || !a.has_primitive() ||
|
||||||
a.primitive().stream() != s || !is_fusable(a.primitive())) {
|
a.primitive().stream() != s || !is_fusable(a.primitive()) ||
|
||||||
|
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -564,13 +578,13 @@ void compile_fuse(
|
|||||||
cache.insert({a.id()});
|
cache.insert({a.id()});
|
||||||
|
|
||||||
for (auto& in : a.inputs()) {
|
for (auto& in : a.inputs()) {
|
||||||
recurse(in, depth + 1, s);
|
recurse(in, depth + 1, s, shape);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (arr.has_primitive()) {
|
if (arr.has_primitive()) {
|
||||||
Stream s = arr.primitive().stream();
|
Stream s = arr.primitive().stream();
|
||||||
recurse(arr, 0, s);
|
recurse(arr, 0, s, arr.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not worth fusing a single primitive
|
// Not worth fusing a single primitive
|
||||||
@ -634,6 +648,10 @@ void compile_fuse(
|
|||||||
std::vector<std::vector<int>> shapes;
|
std::vector<std::vector<int>> shapes;
|
||||||
std::vector<Dtype> types;
|
std::vector<Dtype> types;
|
||||||
for (auto& o : old_outputs) {
|
for (auto& o : old_outputs) {
|
||||||
|
if (o.shape() != old_outputs.back().shape()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[compile] Compilation failed. Tried to fuse operations with different output shapes");
|
||||||
|
}
|
||||||
shapes.push_back(o.shape());
|
shapes.push_back(o.shape());
|
||||||
types.push_back(o.dtype());
|
types.push_back(o.dtype());
|
||||||
}
|
}
|
||||||
@ -676,7 +694,7 @@ void compile_fuse(
|
|||||||
// - Update outputs parents to point to compiled outputs
|
// - Update outputs parents to point to compiled outputs
|
||||||
// - Update any overall graph outputs to be compiled outputs
|
// - Update any overall graph outputs to be compiled outputs
|
||||||
for (int o = 0; o < old_outputs.size(); ++o) {
|
for (int o = 0; o < old_outputs.size(); ++o) {
|
||||||
merge(compiled_outputs[o], old_outputs[o], parents_map);
|
merge_one(compiled_outputs[o], old_outputs[o], parents_map);
|
||||||
if (auto it = output_map.find(old_outputs[o].id());
|
if (auto it = output_map.find(old_outputs[o].id());
|
||||||
it != output_map.end()) {
|
it != output_map.end()) {
|
||||||
it->second = compiled_outputs[o];
|
it->second = compiled_outputs[o];
|
||||||
|
@ -691,6 +691,19 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = mx.compile(fn)(mx.array(10.0), mx.array(20.0))
|
out = mx.compile(fn)(mx.array(10.0), mx.array(20.0))
|
||||||
self.assertEqual(out.item(), 10.0)
|
self.assertEqual(out.item(), 10.0)
|
||||||
|
|
||||||
|
def test_compile_multi_output(self):
|
||||||
|
def fn(x):
|
||||||
|
ys = [x]
|
||||||
|
for i in range(5):
|
||||||
|
ys.append(ys[-1] + x)
|
||||||
|
return ys, mx.sum(ys[-1])
|
||||||
|
|
||||||
|
x = mx.ones(1, dtype=mx.int32)
|
||||||
|
y1 = mx.compile(fn)(x)[1]
|
||||||
|
y2 = fn(x)[1]
|
||||||
|
self.assertEqual(y1.item(), y2.item())
|
||||||
|
self.assertEqual(y1.item(), 6)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user