Limit compile buffers (#1887)

* limit compile buffers

* maybe not flaky test
This commit is contained in:
Awni Hannun 2025-02-19 20:28:13 -08:00 committed by GitHub
parent 78ba24c37d
commit c707b2b0a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 43 additions and 1 deletions

View File

@ -15,6 +15,7 @@
namespace mlx::core { namespace mlx::core {
constexpr int max_compile_depth = 11; constexpr int max_compile_depth = 11;
constexpr int max_compile_arrays = 24;
bool is_unary(const Primitive& p) { bool is_unary(const Primitive& p) {
return ( return (
@ -570,6 +571,7 @@ void compile_fuse(
std::function<void(const array&, int, const Stream&, const Shape&)> recurse; std::function<void(const array&, int, const Stream&, const Shape&)> recurse;
std::unordered_set<uintptr_t> cache; std::unordered_set<uintptr_t> cache;
std::unordered_set<uintptr_t> input_set;
recurse = [&](const array& a, recurse = [&](const array& a,
int depth, int depth,
const Stream& s, const Stream& s,
@ -587,6 +589,8 @@ void compile_fuse(
if (depth >= max_compile_depth || !a.has_primitive() || if (depth >= max_compile_depth || !a.has_primitive() ||
a.primitive().stream() != s || !is_fusable(a.primitive()) || a.primitive().stream() != s || !is_fusable(a.primitive()) ||
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) { (output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
// Possible input
input_set.insert(a.id());
return; return;
} }
@ -607,9 +611,20 @@ void compile_fuse(
// Arrays with a mix of parents outside the compilable section // Arrays with a mix of parents outside the compilable section
// are not fusable // are not fusable
if (!all_parents_in) { if (!all_parents_in) {
// Possible input
input_set.insert(a.id());
return; return;
} }
if (output_map.find(a.id()) != output_map.end()) {
input_set.insert(a.id());
} else {
// Not an input anymore since fusing it
input_set.erase(a.id());
}
if (input_set.size() >= max_compile_arrays) {
return;
}
cache.insert({a.id()}); cache.insert({a.id()});
for (auto& in : a.inputs()) { for (auto& in : a.inputs()) {
@ -630,7 +645,7 @@ void compile_fuse(
// Recurse a second time to build the tape in the right // Recurse a second time to build the tape in the right
// order and collect the inputs // order and collect the inputs
std::unordered_set<uintptr_t> input_set; input_set.clear();
std::vector<array> inputs; std::vector<array> inputs;
std::vector<array> fused_tape; std::vector<array> fused_tape;
std::unordered_set<uintptr_t> tape_set; std::unordered_set<uintptr_t> tape_set;

View File

@ -177,6 +177,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
def test_donation(self): def test_donation(self):
x = mx.random.normal((1024,)) x = mx.random.normal((1024,))
mx.eval(x) mx.eval(x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.metal.reset_peak_memory() mx.metal.reset_peak_memory()
scale = mx.array(2.0) scale = mx.array(2.0)

View File

@ -815,6 +815,31 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(*inputs) out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20))) self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
@mx.compile
def fun(arrs):
for _ in range(6):
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
return arrs[0]
arrs = [mx.array([1.0, 2.0]) for _ in range(64)]
out = fun(arrs)
self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
def test_compile_many_outputs(self):
@mx.compile
def fun(arr):
arrs = [arr] * 64
first_arrs = None
for _ in range(6):
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
if first_arrs is None:
first_arrs = arrs
return arrs[0], first_arrs
out = fun(mx.array([1.0, 2.0]))
self.assertTrue(mx.allclose(out[0], mx.array([64.0, 128.0])))
def test_shapeless_compile_matmul(self): def test_shapeless_compile_matmul(self):
a = mx.array([0.0, 1.0, 2.0]) a = mx.array([0.0, 1.0, 2.0])
b = mx.array([0.0, 1.0, 2.0]) b = mx.array([0.0, 1.0, 2.0])

View File

@ -385,6 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase):
mx.eval(x) mx.eval(x)
save_file = os.path.join(self.test_dir, "donation.npy") save_file = os.path.join(self.test_dir, "donation.npy")
mx.save(save_file, x) mx.save(save_file, x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.metal.reset_peak_memory() mx.metal.reset_peak_memory()
scale = mx.array(2.0) scale = mx.array(2.0)