diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index ced958b13..7ff5c8f9e 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -15,6 +15,7 @@ namespace mlx::core {
 
 constexpr int max_compile_depth = 11;
+constexpr int max_compile_arrays = 24;
 
 bool is_unary(const Primitive& p) {
   return (
@@ -570,6 +571,7 @@ void compile_fuse(
     std::function recurse;
     std::unordered_set cache;
+    std::unordered_set input_set;
     recurse = [&](const array& a,
                   int depth,
                   const Stream& s,
@@ -587,6 +589,8 @@ void compile_fuse(
      if (depth >= max_compile_depth || !a.has_primitive() ||
          a.primitive().stream() != s || !is_fusable(a.primitive()) ||
          (output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
+       // Possible input
+       input_set.insert(a.id());
        return;
      }
@@ -607,9 +611,20 @@ void compile_fuse(
      // Arrays with a mix of parents outside the compilable section
      // are not fusable
      if (!all_parents_in) {
+       // Possible input
+       input_set.insert(a.id());
        return;
      }
 
+     if (output_map.find(a.id()) != output_map.end()) {
+       input_set.insert(a.id());
+     } else {
+       // Not an input anymore since fusing it
+       input_set.erase(a.id());
+     }
+     if (input_set.size() >= max_compile_arrays) {
+       return;
+     }
      cache.insert({a.id()});
 
      for (auto& in : a.inputs()) {
@@ -630,7 +645,7 @@ void compile_fuse(
 
    // Recurse a second time to build the tape in the right
    // order and collect the inputs
-   std::unordered_set input_set;
+   input_set.clear();
    std::vector inputs;
    std::vector fused_tape;
    std::unordered_set tape_set;
diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py
index d63033429..097eb7dc3 100644
--- a/python/tests/mpi_test_distributed.py
+++ b/python/tests/mpi_test_distributed.py
@@ -177,6 +177,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
     def test_donation(self):
         x = mx.random.normal((1024,))
         mx.eval(x)
+        mx.synchronize(mx.default_stream(mx.default_device()))
         mx.metal.reset_peak_memory()
 
         scale = mx.array(2.0)
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index ba6b316ce..ed283317e 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -815,6 +815,31 @@ class TestCompile(mlx_tests.MLXTestCase):
         out = fun(*inputs)
         self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
 
+        @mx.compile
+        def fun(arrs):
+            for _ in range(6):
+                arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
+            return arrs[0]
+
+        arrs = [mx.array([1.0, 2.0]) for _ in range(64)]
+        out = fun(arrs)
+        self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
+
+    def test_compile_many_outputs(self):
+
+        @mx.compile
+        def fun(arr):
+            arrs = [arr] * 64
+            first_arrs = None
+            for _ in range(6):
+                arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
+                if first_arrs is None:
+                    first_arrs = arrs
+            return arrs[0], first_arrs
+
+        out = fun(mx.array([1.0, 2.0]))
+        self.assertTrue(mx.allclose(out[0], mx.array([64.0, 128.0])))
+
     def test_shapeless_compile_matmul(self):
         a = mx.array([0.0, 1.0, 2.0])
         b = mx.array([0.0, 1.0, 2.0])
diff --git a/python/tests/test_load.py b/python/tests/test_load.py
index 335d6ea94..fbc67f3c2 100644
--- a/python/tests/test_load.py
+++ b/python/tests/test_load.py
@@ -385,6 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase):
         mx.eval(x)
         save_file = os.path.join(self.test_dir, "donation.npy")
         mx.save(save_file, x)
+        mx.synchronize(mx.default_stream(mx.default_device()))
         mx.metal.reset_peak_memory()
 
         scale = mx.array(2.0)
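
A minimal standalone sketch of the usage pattern the new `max_compile_arrays` cap and the added tests exercise. This assumes only the public `mlx` Python API (`mx.compile`, `mx.array`, `mx.eval`); the function name `reduce_pairs` and the printed output are illustrative, not part of the patch:

    import mlx.core as mx

    @mx.compile
    def reduce_pairs(arrs):
        # Repeatedly sum adjacent pairs. With 64 input arrays the traced
        # graph is element-wise and fusable end to end; the cap added in
        # compile_fuse presumably breaks the fusion into pieces so that no
        # single Compiled primitive ends up with too many input buffers.
        while len(arrs) > 1:
            arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
        return arrs[0]

    arrs = [mx.array([1.0, 2.0]) for _ in range(64)]
    out = reduce_pairs(arrs)
    mx.eval(out)
    print(out)  # expected: array([64, 128], dtype=float32)

The second new test (`test_compile_many_outputs`) covers the complementary case, where intermediate arrays are also returned and therefore must stay live as outputs of the fused section rather than being absorbed into it.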