Split broadcast so it is always fused in compile (#2318)

This commit is contained in:
Angelos Katharopoulos
2025-06-26 22:08:18 -07:00
committed by GitHub
parent 656ed7f780
commit 2c11d10f8d
2 changed files with 56 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ import io
import math
import unittest
from functools import partial
from io import StringIO
import mlx.core as mx
import mlx_tests
@@ -991,6 +992,28 @@ class TestCompile(mlx_tests.MLXTestCase):
y_compiled = mx.compile(fun)(x).item()
self.assertEqual(y, y_compiled)
def test_shared_broadcast(self):
    """Check that a broadcast used by two outputs still ends up fused.

    ``fun`` broadcasts ``y`` once and consumes the result twice: in the
    elementwise expression ``x + yy * z`` and in ``yy.sum()``. The test
    exports the graph of the compiled outputs to DOT text, expects a
    ``CompiledBroadcastMultiplyAdd`` node to appear in it, and then
    compares the compiled outputs against the uncompiled ones.
    """

    def fun(x, y, z):
        yy = mx.broadcast_to(y, z.shape)
        return (x + yy * z), yy.sum()

    a = mx.random.normal((10, 10))
    b = mx.array(0.1)
    c = mx.random.normal((10, 10))
    # Evaluate the inputs before building the compiled graph.
    mx.eval(a, b, c)

    fc = mx.compile(fun)
    d = fc(a, b, c)

    # Keep the buffer and its text under separate names instead of
    # rebinding one variable from StringIO to str.
    dot_buffer = StringIO()
    mx.export_to_dot(dot_buffer, a=a, b=b, c=c, d1=d[0], d2=d[1])
    dot_source = dot_buffer.getvalue()

    # assertIn reports both operands on failure, unlike assertTrue(x in y).
    self.assertIn("CompiledBroadcastMultiplyAdd", dot_source)

    d_hat = fun(a, b, c)
    self.assertTrue(mx.allclose(d[0], d_hat[0]))
    self.assertTrue(mx.allclose(d[1], d_hat[1]))
# When this file is executed directly, hand control to mlx_tests.MLXTestRunner.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()