Limit compile buffers (#1887)

* limit compile buffers

* maybe not flaky test
This commit is contained in:
Awni Hannun
2025-02-19 20:28:13 -08:00
committed by GitHub
parent 78ba24c37d
commit c707b2b0a6
4 changed files with 43 additions and 1 deletions

View File

@@ -15,6 +15,7 @@
namespace mlx::core {
constexpr int max_compile_depth = 11;
constexpr int max_compile_arrays = 24;
bool is_unary(const Primitive& p) {
return (
@@ -570,6 +571,7 @@ void compile_fuse(
std::function<void(const array&, int, const Stream&, const Shape&)> recurse;
std::unordered_set<uintptr_t> cache;
std::unordered_set<uintptr_t> input_set;
recurse = [&](const array& a,
int depth,
const Stream& s,
@@ -587,6 +589,8 @@ void compile_fuse(
if (depth >= max_compile_depth || !a.has_primitive() ||
a.primitive().stream() != s || !is_fusable(a.primitive()) ||
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
// Possible input
input_set.insert(a.id());
return;
}
@@ -607,9 +611,20 @@ void compile_fuse(
// Arrays with a mix of parents outside the compilable section
// are not fusable
if (!all_parents_in) {
// Possible input
input_set.insert(a.id());
return;
}
if (output_map.find(a.id()) != output_map.end()) {
input_set.insert(a.id());
} else {
// Not an input anymore since fusing it
input_set.erase(a.id());
}
if (input_set.size() >= max_compile_arrays) {
return;
}
cache.insert({a.id()});
for (auto& in : a.inputs()) {
@@ -630,7 +645,7 @@ void compile_fuse(
// Recurse a second time to build the tape in the right
// order and collect the inputs
std::unordered_set<uintptr_t> input_set;
input_set.clear();
std::vector<array> inputs;
std::vector<array> fused_tape;
std::unordered_set<uintptr_t> tape_set;