mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 18:56:39 +08:00
Try a stack-based DFS for eval (#980)
* rebase
* nit
* fix eval in vmap
This commit is contained in:
parent
061cf9a4ce
commit
8580d997ff
@ -4,6 +4,7 @@
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
@ -17,9 +18,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Maximum allowed graph depth for eval
|
||||
constexpr uint32_t max_graph_depth = 100'000;
|
||||
|
||||
/* This class is only meant to be used in eval
|
||||
* for synchronizing with the main thread. */
|
||||
class Synchronizer : public Primitive {
|
||||
@ -44,8 +42,6 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
|
||||
if (global_synchronizer.valid()) {
|
||||
global_synchronizer.wait();
|
||||
}
|
||||
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
||||
@ -62,47 +58,45 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
|
||||
auto synchronizer = array(
|
||||
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
|
||||
|
||||
size_t depth_counter = 0;
|
||||
recurse = [&](const array& a) {
|
||||
if (depth_counter > max_graph_depth) {
|
||||
throw std::runtime_error(
|
||||
"[eval] Graph depth exceeded maximum allowed limit."
|
||||
" Try evaluating the graph more frequently.");
|
||||
}
|
||||
{
|
||||
std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
|
||||
dfs.emplace(synchronizer, 0);
|
||||
while (!dfs.empty()) {
|
||||
auto& [a_ref, idx] = dfs.top();
|
||||
auto& a = a_ref.get();
|
||||
if (idx < a.inputs().size()) {
|
||||
// Add an input, and continue
|
||||
auto& in = a.inputs()[idx++];
|
||||
if (!in.is_evaled()) {
|
||||
if (!in.has_primitive()) {
|
||||
throw std::invalid_argument(
|
||||
"[eval] Attempting to eval an array without a primitive.");
|
||||
}
|
||||
|
||||
auto id = a.id();
|
||||
if (cache.find(id) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Recurse to the largest or smallest branch first.
|
||||
depth_counter++;
|
||||
for (auto& in : a.inputs()) {
|
||||
recurse(in);
|
||||
if (!in.is_evaled()) {
|
||||
// If the input is being computed on a different stream, we need to
|
||||
// manage the dependency.
|
||||
if (a.primitive().stream() != in.primitive().stream()) {
|
||||
deps.insert({in.output(0).id(), std::shared_future<void>{}});
|
||||
// If the input is being computed on a different stream, we need to
|
||||
// manage the dependency.
|
||||
if (a.primitive().stream() != in.primitive().stream()) {
|
||||
deps.insert({in.output(0).id(), std::shared_future<void>{}});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
depth_counter--;
|
||||
|
||||
cache.insert(id);
|
||||
for (auto& s : a.siblings()) {
|
||||
cache.insert(s.id());
|
||||
}
|
||||
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
||||
if (!a.has_primitive()) {
|
||||
throw std::invalid_argument(
|
||||
"[eval] Attempting to eval an array without a primitive.");
|
||||
if (cache.find(in.id()) == cache.end()) {
|
||||
dfs.emplace(in, 0);
|
||||
cache.insert(in.id());
|
||||
for (auto& s : in.siblings()) {
|
||||
cache.insert(s.id());
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
tape.push(a);
|
||||
}
|
||||
};
|
||||
|
||||
recurse(synchronizer);
|
||||
// All inputs are done being processed, process this array
|
||||
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
||||
tape.push(a);
|
||||
}
|
||||
dfs.pop();
|
||||
}
|
||||
}
|
||||
deps.insert({synchronizer.id(), std::shared_future<void>{}});
|
||||
|
||||
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
||||
|
Loading…
Reference in New Issue
Block a user