mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 23:21:16 +08:00
Fix deep recursion with siblings (#1462)
* fix recursion with siblings
* fix
* add test
* increase tol
This commit is contained in:
parent
95d04805b3
commit
0070e1db40
@ -242,25 +242,35 @@ array::ArrayDesc::~ArrayDesc() {
|
|||||||
// This calls recursively the destructor and can result in stack overflow, we
|
// This calls recursively the destructor and can result in stack overflow, we
|
||||||
// instead put them in a vector and destroy them one at a time resulting in a
|
// instead put them in a vector and destroy them one at a time resulting in a
|
||||||
// max stack depth of 2.
|
// max stack depth of 2.
|
||||||
|
if (inputs.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
|
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
|
||||||
|
|
||||||
for (array& a : inputs) {
|
auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
|
||||||
if (a.array_desc_.use_count() == 1) {
|
std::unordered_map<std::uintptr_t, array> input_map;
|
||||||
|
for (array& a : ad.inputs) {
|
||||||
|
if (a.array_desc_) {
|
||||||
|
input_map.insert({a.id(), a});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ad.inputs.clear();
|
||||||
|
for (auto& [_, a] : input_map) {
|
||||||
|
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
|
||||||
for_deletion.push_back(std::move(a.array_desc_));
|
for_deletion.push_back(std::move(a.array_desc_));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
append_deletable_inputs(*this);
|
||||||
|
|
||||||
while (!for_deletion.empty()) {
|
while (!for_deletion.empty()) {
|
||||||
// top is going to be deleted at the end of the block *after* the arrays
|
// top is going to be deleted at the end of the block *after* the arrays
|
||||||
// with inputs have been moved into the vector
|
// with inputs have been moved into the vector
|
||||||
auto top = std::move(for_deletion.back());
|
auto top = std::move(for_deletion.back());
|
||||||
for_deletion.pop_back();
|
for_deletion.pop_back();
|
||||||
|
append_deletable_inputs(*top);
|
||||||
for (array& a : top->inputs) {
|
|
||||||
if (a.array_desc_.use_count() == 1) {
|
|
||||||
for_deletion.push_back(std::move(a.array_desc_));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1868,6 +1868,33 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
int(a)
|
int(a)
|
||||||
|
|
||||||
|
def test_deep_graphs(self):
|
||||||
|
# The following tests should simply run cleanly without a segfault or
|
||||||
|
# crash due to exceeding recursion depth limits.
|
||||||
|
|
||||||
|
# Deep graph destroyed without eval
|
||||||
|
x = mx.array([1.0, 2.0])
|
||||||
|
for _ in range(100_000):
|
||||||
|
x = mx.sin(x)
|
||||||
|
del x
|
||||||
|
|
||||||
|
# Duplicate input deep graph destroyed without eval
|
||||||
|
x = mx.array([1.0, 2.0])
|
||||||
|
for _ in range(100_000):
|
||||||
|
x = x + x
|
||||||
|
|
||||||
|
# Deep graph with siblings destroyed without eval
|
||||||
|
x = mx.array([1, 2])
|
||||||
|
for _ in range(100_000):
|
||||||
|
x = mx.concatenate(mx.split(x, 2))
|
||||||
|
del x
|
||||||
|
|
||||||
|
# Deep graph with eval
|
||||||
|
x = mx.array([1.0, 2.0])
|
||||||
|
for _ in range(100_000):
|
||||||
|
x = mx.sin(x)
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -902,7 +902,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
dw10 = (cotan[1::s, :-1:s] * x).sum()
|
dw10 = (cotan[1::s, :-1:s] * x).sum()
|
||||||
dw11 = (cotan[1::s, 1::s] * x).sum()
|
dw11 = (cotan[1::s, 1::s] * x).sum()
|
||||||
expected = mx.array([[dw00, dw01], [dw10, dw11]])
|
expected = mx.array([[dw00, dw01], [dw10, dw11]])
|
||||||
self.assertTrue(mx.allclose(dw, expected))
|
self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5))
|
||||||
|
|
||||||
def test_conv_groups_grad(self):
|
def test_conv_groups_grad(self):
|
||||||
def fn(x, w):
|
def fn(x, w):
|
||||||
|
Loading…
Reference in New Issue
Block a user