Fix compile when outputs change (#2648)

This commit is contained in:
Awni Hannun
2025-10-03 08:40:57 -07:00
committed by GitHub
parent 22a5da76c8
commit a7a94b29d7
2 changed files with 30 additions and 0 deletions

View File

@@ -1115,6 +1115,25 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertEqual(state[0].item(), 4)
def test_outputs_changing(self):
@mx.compile
def fun(x):
x = mx.abs(mx.negative(x))
y = mx.abs(x)
return x, y
@mx.compile
def fun2(x):
x = mx.abs(mx.negative(x))
y = mx.abs(x)
return y
a, b = fun(mx.array(-1.0))
mx.eval(a, b)
a = fun2(mx.array(-1.0))
self.assertEqual(a.item(), 1.0)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()