Add a maximum graph depth (#797)

* add a maximum graph depth

* remember how to use C++
This commit is contained in:
Awni Hannun 2024-03-06 15:39:00 -08:00 committed by GitHub
parent 7762e07fde
commit 1074674e32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 2 deletions

View File

@ -274,7 +274,7 @@ class array {
};
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
uint32_t graph_depth() const {
return array_desc_->depth;
}
@ -389,7 +389,7 @@ class array {
uint32_t position{0};
// The depth of the array in the graph.
uint16_t depth{0};
uint32_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);

View File

@ -17,6 +17,9 @@
namespace mlx::core {
// Maximum allowed graph depth for eval
constexpr uint32_t max_graph_depth = 100'000;
/* This class is only meant to be used in eval
* for synchronizing with the main thread. */
class Synchronizer : public Primitive {
@ -116,6 +119,11 @@ void eval(const std::vector<array>& outputs) {
}
};
if (synchronizer.graph_depth() > max_graph_depth) {
throw std::runtime_error(
"[eval] Graph depth exceeded maximum allowed limit."
" Try evaluating the graph more frequently.");
}
recurse(synchronizer, false);
uintptr_t synch_id = synchronizer.primitive_id();
deps.insert({synch_id, std::shared_future<void>{}});