This commit is contained in:
Awni Hannun
2025-10-16 07:30:56 -07:00
parent c473719b23
commit f8b6f8a3dc
3 changed files with 73 additions and 39 deletions

View File

@@ -1134,6 +1134,30 @@ class TestCompile(mlx_tests.MLXTestCase):
a = fun2(mx.array(-1.0))
self.assertEqual(a.item(), 1.0)
def test_multiple_compile_same_capture(self):
def fun(do_compile):
t = mx.ones((10,))
u = (1.0 - t) * 0.0 + t * 3.0
o = mx.ones((6,))
b = o[:, None] * u
c = b * mx.ones_like(u)
a = mx.ones((6,))
if do_compile:
d = mx.compile(lambda x: x @ b)(a)
e = mx.compile(lambda x: x @ c.T)(d)
else:
d = a @ b
e = d @ c.T
return e
out = fun(True)
mx.eval(out)
expected = fun(False)
self.assertTrue(mx.allclose(out, expected))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()