From 0070e1db40759fa098ccf2b374cb9794dee32e38 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 7 Oct 2024 06:15:33 -0700 Subject: [PATCH] Fix deep recursion with siblings (#1462) * fix recursion with siblings * fix * add test * increase tol --- mlx/array.cpp | 30 ++++++++++++++++++++---------- python/tests/test_array.py | 27 +++++++++++++++++++++++++++ python/tests/test_conv.py | 2 +- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index d5c141e70..99c6c23b7 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -242,25 +242,35 @@ array::ArrayDesc::~ArrayDesc() { // This calls recursively the destructor and can result in stack overflow, we // instead put them in a vector and destroy them one at a time resulting in a // max stack depth of 2. + if (inputs.empty()) { + return; + } + std::vector<std::shared_ptr<ArrayDesc>> for_deletion; - for (array& a : inputs) { - if (a.array_desc_.use_count() == 1) { - for_deletion.push_back(std::move(a.array_desc_)); + auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) { + std::unordered_map<std::uintptr_t, array> input_map; + for (array& a : ad.inputs) { + if (a.array_desc_) { + input_map.insert({a.id(), a}); + } } - } + ad.inputs.clear(); + for (auto& [_, a] : input_map) { + if (a.array_desc_.use_count() <= a.siblings().size() + 1) { + for_deletion.push_back(std::move(a.array_desc_)); + } + } + }; + + append_deletable_inputs(*this); while (!for_deletion.empty()) { // top is going to be deleted at the end of the block *after* the arrays // with inputs have been moved into the vector auto top = std::move(for_deletion.back()); for_deletion.pop_back(); - - for (array& a : top->inputs) { - if (a.array_desc_.use_count() == 1) { - for_deletion.push_back(std::move(a.array_desc_)); - } - } + append_deletable_inputs(*top); } } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 9b9cc7e21..cbbcd547c 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1868,6 +1868,33 @@ class 
TestArray(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): int(a) + def test_deep_graphs(self): + # The following tests should simply run cleanly without a segfault or + # crash due to exceeding recursion depth limits. + + # Deep graph destroyed without eval + x = mx.array([1.0, 2.0]) + for _ in range(100_000): + x = mx.sin(x) + del x + + # Duplicate input deep graph destroyed without eval + x = mx.array([1.0, 2.0]) + for _ in range(100_000): + x = x + x + + # Deep graph with siblings destroyed without eval + x = mx.array([1, 2]) + for _ in range(100_000): + x = mx.concatenate(mx.split(x, 2)) + del x + + # Deep graph with eval + x = mx.array([1.0, 2.0]) + for _ in range(100_000): + x = mx.sin(x) + mx.eval(x) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index f6bff01cb..e446e1df8 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -902,7 +902,7 @@ class TestConv(mlx_tests.MLXTestCase): dw10 = (cotan[1::s, :-1:s] * x).sum() dw11 = (cotan[1::s, 1::s] * x).sum() expected = mx.array([[dw00, dw01], [dw10, dw11]]) - self.assertTrue(mx.allclose(dw, expected)) + self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5)) def test_conv_groups_grad(self): def fn(x, w):