Fix compile fusion for multi-output edge cases (#950)

* Fix compile fusion for multi-output edge cases

* Add a test for multi-output compile
This commit is contained in:
Angelos Katharopoulos
2024-04-02 08:42:31 -07:00
committed by GitHub
parent 2427fa171e
commit 1a87dc5ea8
2 changed files with 49 additions and 18 deletions

View File

@@ -691,6 +691,19 @@ class TestCompile(mlx_tests.MLXTestCase):
out = mx.compile(fn)(mx.array(10.0), mx.array(20.0))
self.assertEqual(out.item(), 10.0)
def test_compile_multi_output(self):
def fn(x):
ys = [x]
for i in range(5):
ys.append(ys[-1] + x)
return ys, mx.sum(ys[-1])
x = mx.ones(1, dtype=mx.int32)
y1 = mx.compile(fn)(x)[1]
y2 = fn(x)[1]
self.assertEqual(y1.item(), y2.item())
self.assertEqual(y1.item(), 6)
if __name__ == "__main__":
unittest.main()