From 2c11d10f8d8e4d124cb447af731b9199374695bb Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 26 Jun 2025 22:08:18 -0700
Subject: [PATCH] Split broadcast so it is always fused in compile (#2318)

---
 mlx/compile.cpp              | 36 +++++++++++++++++++++++++++++++++---
 python/tests/test_compile.py | 23 +++++++++++++++++++++++
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index 79a55ba8f..0cb3b5a85 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -245,6 +245,30 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
   }
 }
 
+// Any parent in the divider is rewired to a copy of `x` (the returned array);
+// any parent not in the divider continues to refer to `x` itself.
+array split_one(
+    const array& x,
+    ParentsMap& parents_map,
+    const std::unordered_set<uintptr_t>& divider) {
+  array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());
+
+  auto& x_parents = parents_map[x.id()];
+  auto& y_parents = parents_map[y.id()];
+
+  for (auto it = x_parents.begin(); it != x_parents.end();) {
+    if (divider.find(it->first.id()) != divider.end()) {
+      it->first.inputs()[it->second] = y;
+      y_parents.emplace_back(std::move(*it));
+      it = x_parents.erase(it);
+    } else {
+      ++it;
+    }
+  }
+
+  return y;
+}
+
 template <typename T, typename... U>
 std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
   using FunType = T (*)(U...);
@@ -669,10 +693,16 @@ void compile_fuse(
       }
 
       // Arrays with a mix of parents outside the compilable section
-      // are not fusable
+      // are not fusable except for broadcast which we can split to avoid
+      // stopping fusion
       if (!all_parents_in) {
-        // Possible input
-        input_set.insert(a.id());
+        if (a.has_primitive() && is_broadcast(a.primitive())) {
+          array b = split_one(a, parents_map, cache);
+          recurse(b, depth, s, shape);
+        } else {
+          // Possible input
+          input_set.insert(a.id());
+        }
         return;
       }
 
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index ca33c2d3a..ada2b1484 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -5,6 +5,7 @@ import io
 import math
 import unittest
 from functools import partial
+from io import StringIO
 
 import mlx.core as mx
 import mlx_tests
@@ -991,6 +992,28 @@ class TestCompile(mlx_tests.MLXTestCase):
         y_compiled = mx.compile(fun)(x).item()
         self.assertEqual(y, y_compiled)
 
+    def test_shared_broadcast(self):
+        def fun(x, y, z):
+            yy = mx.broadcast_to(y, z.shape)
+            return (x + yy * z), yy.sum()
+
+        a = mx.random.normal((10, 10))
+        b = mx.array(0.1)
+        c = mx.random.normal((10, 10))
+        mx.eval(a, b, c)
+        fc = mx.compile(fun)
+        d = fc(a, b, c)
+
+        s = StringIO()
+        mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1])
+        s.seek(0)
+        dot = s.read()
+
+        self.assertIn("CompiledBroadcastMultiplyAdd", dot)
+        d_hat = fun(a, b, c)
+        self.assertTrue(mx.allclose(d[0], d_hat[0]))
+        self.assertTrue(mx.allclose(d[1], d_hat[1]))
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()