Try a stack-based DFS for eval (#980)

* rebase

* nit

* fix eval in vmap
This commit is contained in:
Awni Hannun 2024-04-10 17:05:13 -07:00 committed by GitHub
parent 061cf9a4ce
commit 8580d997ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@
#include <numeric>
#include <set>
#include <sstream>
#include <stack>
#include <unordered_map>
#include <unordered_set>
@ -17,9 +18,6 @@
namespace mlx::core {
// Maximum allowed graph depth for eval
constexpr uint32_t max_graph_depth = 100'000;
/* This class is only meant to be used in eval
* for synchronizing with the main thread. */
class Synchronizer : public Primitive {
@ -44,8 +42,6 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
if (global_synchronizer.valid()) {
global_synchronizer.wait();
}
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
@ -62,47 +58,45 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
auto synchronizer = array(
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
size_t depth_counter = 0;
recurse = [&](const array& a) {
if (depth_counter > max_graph_depth) {
throw std::runtime_error(
"[eval] Graph depth exceeded maximum allowed limit."
" Try evaluating the graph more frequently.");
}
{
std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
dfs.emplace(synchronizer, 0);
while (!dfs.empty()) {
auto& [a_ref, idx] = dfs.top();
auto& a = a_ref.get();
if (idx < a.inputs().size()) {
// Add an input, and continue
auto& in = a.inputs()[idx++];
if (!in.is_evaled()) {
if (!in.has_primitive()) {
throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive.");
}
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
// Recurse to the largest or smallest branch first.
depth_counter++;
for (auto& in : a.inputs()) {
recurse(in);
if (!in.is_evaled()) {
// If the input is being computed on a different stream, we need to
// manage the dependency.
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.output(0).id(), std::shared_future<void>{}});
// If the input is being computed on a different stream, we need to
// manage the dependency.
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.output(0).id(), std::shared_future<void>{}});
}
}
}
}
depth_counter--;
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
}
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
if (!a.has_primitive()) {
throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive.");
if (cache.find(in.id()) == cache.end()) {
dfs.emplace(in, 0);
cache.insert(in.id());
for (auto& s : in.siblings()) {
cache.insert(s.id());
}
}
continue;
}
tape.push(a);
}
};
recurse(synchronizer);
// All inputs are done being processed, process this array
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
tape.push(a);
}
dfs.pop();
}
}
deps.insert({synchronizer.id(), std::shared_future<void>{}});
std::vector<std::shared_ptr<std::promise<void>>> ps;