Fix deep recursion with siblings (#1462)

* fix recursion with siblings

* fix

* add test

* increase tol
This commit is contained in:
Awni Hannun
2024-10-07 06:15:33 -07:00
committed by GitHub
parent 95d04805b3
commit 0070e1db40
3 changed files with 48 additions and 11 deletions

View File

@@ -1868,6 +1868,33 @@ class TestArray(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
int(a)
def test_deep_graphs(self):
# The following tests should simply run cleanly without a segfault or
# crash due to exceeding recursion depth limits.
# Deep graph destroyed without eval
x = mx.array([1.0, 2.0])
for _ in range(100_000):
x = mx.sin(x)
del x
# Duplicate input deep graph destroyed without eval
x = mx.array([1.0, 2.0])
for _ in range(100_000):
x = x + x
# Deep graph with siblings destroyed without eval
x = mx.array([1, 2])
for _ in range(100_000):
x = mx.concatenate(mx.split(x, 2))
del x
# Deep graph with eval
x = mx.array([1.0, 2.0])
for _ in range(100_000):
x = mx.sin(x)
mx.eval(x)
if __name__ == "__main__":
unittest.main()