mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Fix deep recursion with siblings (#1462)
* fix recursion with siblings * fix * add test * increase tol
This commit is contained in:
		@@ -242,25 +242,35 @@ array::ArrayDesc::~ArrayDesc() {
 | 
				
			|||||||
  // This calls recursively the destructor and can result in stack overflow, we
 | 
					  // This calls recursively the destructor and can result in stack overflow, we
 | 
				
			||||||
  // instead put them in a vector and destroy them one at a time resulting in a
 | 
					  // instead put them in a vector and destroy them one at a time resulting in a
 | 
				
			||||||
  // max stack depth of 2.
 | 
					  // max stack depth of 2.
 | 
				
			||||||
 | 
					  if (inputs.empty()) {
 | 
				
			||||||
 | 
					    return;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
 | 
					  std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  for (array& a : inputs) {
 | 
					  auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
 | 
				
			||||||
    if (a.array_desc_.use_count() == 1) {
 | 
					    std::unordered_map<std::uintptr_t, array> input_map;
 | 
				
			||||||
 | 
					    for (array& a : ad.inputs) {
 | 
				
			||||||
 | 
					      if (a.array_desc_) {
 | 
				
			||||||
 | 
					        input_map.insert({a.id(), a});
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    ad.inputs.clear();
 | 
				
			||||||
 | 
					    for (auto& [_, a] : input_map) {
 | 
				
			||||||
 | 
					      if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
 | 
				
			||||||
        for_deletion.push_back(std::move(a.array_desc_));
 | 
					        for_deletion.push_back(std::move(a.array_desc_));
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  append_deletable_inputs(*this);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  while (!for_deletion.empty()) {
 | 
					  while (!for_deletion.empty()) {
 | 
				
			||||||
    // top is going to be deleted at the end of the block *after* the arrays
 | 
					    // top is going to be deleted at the end of the block *after* the arrays
 | 
				
			||||||
    // with inputs have been moved into the vector
 | 
					    // with inputs have been moved into the vector
 | 
				
			||||||
    auto top = std::move(for_deletion.back());
 | 
					    auto top = std::move(for_deletion.back());
 | 
				
			||||||
    for_deletion.pop_back();
 | 
					    for_deletion.pop_back();
 | 
				
			||||||
 | 
					    append_deletable_inputs(*top);
 | 
				
			||||||
    for (array& a : top->inputs) {
 | 
					 | 
				
			||||||
      if (a.array_desc_.use_count() == 1) {
 | 
					 | 
				
			||||||
        for_deletion.push_back(std::move(a.array_desc_));
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1868,6 +1868,33 @@ class TestArray(mlx_tests.MLXTestCase):
 | 
				
			|||||||
        with self.assertRaises(ValueError):
 | 
					        with self.assertRaises(ValueError):
 | 
				
			||||||
            int(a)
 | 
					            int(a)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_deep_graphs(self):
 | 
				
			||||||
 | 
					        # The following tests should simply run cleanly without a segfault or
 | 
				
			||||||
 | 
					        # crash due to exceeding recursion depth limits.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Deep graph destroyed without eval
 | 
				
			||||||
 | 
					        x = mx.array([1.0, 2.0])
 | 
				
			||||||
 | 
					        for _ in range(100_000):
 | 
				
			||||||
 | 
					            x = mx.sin(x)
 | 
				
			||||||
 | 
					        del x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Duplicate input deep graph destroyed without eval
 | 
				
			||||||
 | 
					        x = mx.array([1.0, 2.0])
 | 
				
			||||||
 | 
					        for _ in range(100_000):
 | 
				
			||||||
 | 
					            x = x + x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Deep graph with siblings destroyed without eval
 | 
				
			||||||
 | 
					        x = mx.array([1, 2])
 | 
				
			||||||
 | 
					        for _ in range(100_000):
 | 
				
			||||||
 | 
					            x = mx.concatenate(mx.split(x, 2))
 | 
				
			||||||
 | 
					        del x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Deep graph with eval
 | 
				
			||||||
 | 
					        x = mx.array([1.0, 2.0])
 | 
				
			||||||
 | 
					        for _ in range(100_000):
 | 
				
			||||||
 | 
					            x = mx.sin(x)
 | 
				
			||||||
 | 
					        mx.eval(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    unittest.main()
 | 
					    unittest.main()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -902,7 +902,7 @@ class TestConv(mlx_tests.MLXTestCase):
 | 
				
			|||||||
            dw10 = (cotan[1::s, :-1:s] * x).sum()
 | 
					            dw10 = (cotan[1::s, :-1:s] * x).sum()
 | 
				
			||||||
            dw11 = (cotan[1::s, 1::s] * x).sum()
 | 
					            dw11 = (cotan[1::s, 1::s] * x).sum()
 | 
				
			||||||
            expected = mx.array([[dw00, dw01], [dw10, dw11]])
 | 
					            expected = mx.array([[dw00, dw01], [dw10, dw11]])
 | 
				
			||||||
            self.assertTrue(mx.allclose(dw, expected))
 | 
					            self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_conv_groups_grad(self):
 | 
					    def test_conv_groups_grad(self):
 | 
				
			||||||
        def fn(x, w):
 | 
					        def fn(x, w):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user