mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-29 21:11:16 +08:00
Add a maximum graph depth (#797)
* add a maximum graph depth * remember how to use C++
This commit is contained in:
parent
7762e07fde
commit
1074674e32
@ -274,7 +274,7 @@ class array {
|
||||
};
|
||||
|
||||
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
|
||||
uint16_t graph_depth() const {
|
||||
uint32_t graph_depth() const {
|
||||
return array_desc_->depth;
|
||||
}
|
||||
|
||||
@ -389,7 +389,7 @@ class array {
|
||||
uint32_t position{0};
|
||||
|
||||
// The depth of the array in the graph.
|
||||
uint16_t depth{0};
|
||||
uint32_t depth{0};
|
||||
|
||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||
|
||||
|
@ -17,6 +17,9 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Maximum allowed graph depth for eval
|
||||
constexpr uint32_t max_graph_depth = 100'000;
|
||||
|
||||
/* This class is only meant to be used in eval
|
||||
* for synchronizing with the main thread. */
|
||||
class Synchronizer : public Primitive {
|
||||
@ -116,6 +119,11 @@ void eval(const std::vector<array>& outputs) {
|
||||
}
|
||||
};
|
||||
|
||||
if (synchronizer.graph_depth() > max_graph_depth) {
|
||||
throw std::runtime_error(
|
||||
"[eval] Graph depth exceeded maximum allowed limit."
|
||||
" Try evaluating the graph more frequently.");
|
||||
}
|
||||
recurse(synchronizer, false);
|
||||
uintptr_t synch_id = synchronizer.primitive_id();
|
||||
deps.insert({synch_id, std::shared_future<void>{}});
|
||||
|
Loading…
Reference in New Issue
Block a user