Split broadcast so it is always fused in compile (#2318)

This commit is contained in:
Angelos Katharopoulos 2025-06-26 22:08:18 -07:00 committed by GitHub
parent 656ed7f780
commit 2c11d10f8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 3 deletions

View File

@ -245,6 +245,30 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
} }
} }
// Split `x` into two nodes that perform the same operation.
//
// A new array `y` is created from `x`'s shape, dtype, primitive and inputs.
// Every parent of `x` whose id is in `divider` is rewired to read from `y`
// and its entry is moved from `x`'s parent list to `y`'s parent list in
// `parents_map`. Parents that are not in `divider` continue to refer to `x`.
//
// Returns the copy `y`.
array split_one(
    const array& x,
    ParentsMap& parents_map,
    const std::unordered_set<uintptr_t>& divider) {
  array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());

  auto& x_parents = parents_map[x.id()];
  auto& y_parents = parents_map[y.id()];

  // Each entry is (parent array, index of the input slot in that parent).
  for (auto it = x_parents.begin(); it != x_parents.end();) {
    if (divider.find(it->first.id()) != divider.end()) {
      // Point the parent's input slot at the copy and hand the entry to `y`.
      it->first.inputs()[it->second] = y;
      y_parents.emplace_back(std::move(*it));
      it = x_parents.erase(it);
    } else {
      ++it;
    }
  }

  // Return the local by name so copy elision / implicit move applies.
  return y;
}
template <typename T, typename... U> template <typename T, typename... U>
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) { std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
using FunType = T (*)(U...); using FunType = T (*)(U...);
@ -669,10 +693,16 @@ void compile_fuse(
} }
// Arrays with a mix of parents outside the compilable section // Arrays with a mix of parents outside the compilable section
// are not fusable // are not fusable except for broadcast which we can split to avoid
// stopping fusion
if (!all_parents_in) { if (!all_parents_in) {
// Possible input if (a.has_primitive() && is_broadcast(a.primitive())) {
input_set.insert(a.id()); array b = split_one(a, parents_map, cache);
recurse(b, depth, s, shape);
} else {
// Possible input
input_set.insert(a.id());
}
return; return;
} }

View File

@ -5,6 +5,7 @@ import io
import math import math
import unittest import unittest
from functools import partial from functools import partial
from io import StringIO
import mlx.core as mx import mlx.core as mx
import mlx_tests import mlx_tests
@ -991,6 +992,28 @@ class TestCompile(mlx_tests.MLXTestCase):
y_compiled = mx.compile(fun)(x).item() y_compiled = mx.compile(fun)(x).item()
self.assertEqual(y, y_compiled) self.assertEqual(y, y_compiled)
def test_shared_broadcast(self):
    # A broadcast that feeds both a fusable expression and a separate
    # reduction should still show up fused in the compiled graph.
    def fun(x, y, z):
        yy = mx.broadcast_to(y, z.shape)
        fused = x + yy * z
        total = yy.sum()
        return fused, total

    lhs = mx.random.normal((10, 10))
    scalar = mx.array(0.1)
    rhs = mx.random.normal((10, 10))
    mx.eval(lhs, scalar, rhs)

    compiled_fun = mx.compile(fun)
    out_fused, out_sum = compiled_fun(lhs, scalar, rhs)

    # Dump the graph and check for the fused primitive by name.
    dot_buffer = StringIO()
    mx.export_to_dot(dot_buffer, a=lhs, b=scalar, c=rhs, d1=out_fused, d2=out_sum)
    dot_text = dot_buffer.getvalue()
    self.assertTrue("CompiledBroadcastMultiplyAdd" in dot_text)

    # The compiled outputs must match the eager outputs.
    ref_fused, ref_sum = fun(lhs, scalar, rhs)
    self.assertTrue(mx.allclose(out_fused, ref_fused))
    self.assertTrue(mx.allclose(out_sum, ref_sum))
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()