mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
Split broadcast so it is always fused in compile (#2318)
This commit is contained in:

committed by
GitHub

parent
656ed7f780
commit
2c11d10f8d
@@ -5,6 +5,7 @@ import io
|
||||
import math
|
||||
import unittest
|
||||
from functools import partial
|
||||
from io import StringIO
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
@@ -991,6 +992,28 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
y_compiled = mx.compile(fun)(x).item()
|
||||
self.assertEqual(y, y_compiled)
|
||||
|
||||
def test_shared_broadcast(self):
|
||||
def fun(x, y, z):
|
||||
yy = mx.broadcast_to(y, z.shape)
|
||||
return (x + yy * z), yy.sum()
|
||||
|
||||
a = mx.random.normal((10, 10))
|
||||
b = mx.array(0.1)
|
||||
c = mx.random.normal((10, 10))
|
||||
mx.eval(a, b, c)
|
||||
fc = mx.compile(fun)
|
||||
d = fc(a, b, c)
|
||||
|
||||
s = StringIO()
|
||||
mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1])
|
||||
s.seek(0)
|
||||
s = s.read()
|
||||
|
||||
self.assertTrue("CompiledBroadcastMultiplyAdd" in s)
|
||||
d_hat = fun(a, b, c)
|
||||
self.assertTrue(mx.allclose(d[0], d_hat[0]))
|
||||
self.assertTrue(mx.allclose(d[1], d_hat[1]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Reference in New Issue
Block a user