mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-30 21:51:25 +08:00
Add a maximum graph depth (#797)
* add a maximum graph depth * remember how to use C++
This commit is contained in:
parent
7762e07fde
commit
1074674e32
@ -274,7 +274,7 @@ class array {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
|
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
|
||||||
uint16_t graph_depth() const {
|
uint32_t graph_depth() const {
|
||||||
return array_desc_->depth;
|
return array_desc_->depth;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -389,7 +389,7 @@ class array {
|
|||||||
uint32_t position{0};
|
uint32_t position{0};
|
||||||
|
|
||||||
// The depth of the array in the graph.
|
// The depth of the array in the graph.
|
||||||
uint16_t depth{0};
|
uint32_t depth{0};
|
||||||
|
|
||||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||||
|
|
||||||
|
@ -17,6 +17,9 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Maximum allowed graph depth for eval
|
||||||
|
constexpr uint32_t max_graph_depth = 100'000;
|
||||||
|
|
||||||
/* This class is only meant to be used in eval
|
/* This class is only meant to be used in eval
|
||||||
* for synchronizing with the main thread. */
|
* for synchronizing with the main thread. */
|
||||||
class Synchronizer : public Primitive {
|
class Synchronizer : public Primitive {
|
||||||
@ -116,6 +119,11 @@ void eval(const std::vector<array>& outputs) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (synchronizer.graph_depth() > max_graph_depth) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[eval] Graph depth exceeded maximum allowed limit."
|
||||||
|
" Try evaluating the graph more frequently.");
|
||||||
|
}
|
||||||
recurse(synchronizer, false);
|
recurse(synchronizer, false);
|
||||||
uintptr_t synch_id = synchronizer.primitive_id();
|
uintptr_t synch_id = synchronizer.primitive_id();
|
||||||
deps.insert({synch_id, std::shared_future<void>{}});
|
deps.insert({synch_id, std::shared_future<void>{}});
|
||||||
|
Loading…
Reference in New Issue
Block a user