Fix compile multi capture (#2678)

* fix compile when compiling multiple lambdas with the same capture

* add test
This commit is contained in:
Awni Hannun
2025-11-03 06:33:43 -08:00
committed by GitHub
parent 78678de0cd
commit 93d76b0f30
5 changed files with 139 additions and 41 deletions

View File

@@ -1134,6 +1134,30 @@ class TestCompile(mlx_tests.MLXTestCase):
a = fun2(mx.array(-1.0))
self.assertEqual(a.item(), 1.0)
def test_multiple_compile_same_capture(self):
def fun(do_compile):
t = mx.ones((10,))
u = (1.0 - t) * 0.0 + t * 3.0
o = mx.ones((6,))
b = o[:, None] * u
c = b * mx.ones_like(u)
a = mx.ones((6,))
if do_compile:
d = mx.compile(lambda x: x @ b)(a)
e = mx.compile(lambda x: x @ c.T)(d)
else:
d = a @ b
e = d @ c.T
return e
out = fun(True)
mx.eval(out)
expected = fun(False)
self.assertTrue(mx.allclose(out, expected))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()