mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 22:44:38 +08:00
Fix deep recursion with siblings (#1462)
* fix recursion with siblings * fix * add test * increase tol
This commit is contained in:
@@ -1868,6 +1868,33 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
int(a)
|
||||
|
||||
def test_deep_graphs(self):
|
||||
# The following tests should simply run cleanly without a segfault or
|
||||
# crash due to exceeding recursion depth limits.
|
||||
|
||||
# Deep graph destroyed without eval
|
||||
x = mx.array([1.0, 2.0])
|
||||
for _ in range(100_000):
|
||||
x = mx.sin(x)
|
||||
del x
|
||||
|
||||
# Duplicate input deep graph destroyed without eval
|
||||
x = mx.array([1.0, 2.0])
|
||||
for _ in range(100_000):
|
||||
x = x + x
|
||||
|
||||
# Deep graph with siblings destroyed without eval
|
||||
x = mx.array([1, 2])
|
||||
for _ in range(100_000):
|
||||
x = mx.concatenate(mx.split(x, 2))
|
||||
del x
|
||||
|
||||
# Deep graph with eval
|
||||
x = mx.array([1.0, 2.0])
|
||||
for _ in range(100_000):
|
||||
x = mx.sin(x)
|
||||
mx.eval(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user