mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 11:48:37 +08:00
Try a stack-based DFS for eval (#980)
* rebase
* nit
* fix eval in vmap
This commit is contained in:
parent
061cf9a4ce
commit
8580d997ff
@ -4,6 +4,7 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <stack>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
@ -17,9 +18,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
// Maximum allowed graph depth for eval
|
|
||||||
constexpr uint32_t max_graph_depth = 100'000;
|
|
||||||
|
|
||||||
/* This class is only meant to be used in eval
|
/* This class is only meant to be used in eval
|
||||||
* for synchronizing with the main thread. */
|
* for synchronizing with the main thread. */
|
||||||
class Synchronizer : public Primitive {
|
class Synchronizer : public Primitive {
|
||||||
@ -44,8 +42,6 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
|
|||||||
if (global_synchronizer.valid()) {
|
if (global_synchronizer.valid()) {
|
||||||
global_synchronizer.wait();
|
global_synchronizer.wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::function<void(const array&)> recurse;
|
|
||||||
std::queue<array> tape;
|
std::queue<array> tape;
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
||||||
@ -62,47 +58,45 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
|
|||||||
auto synchronizer = array(
|
auto synchronizer = array(
|
||||||
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
|
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
|
||||||
|
|
||||||
size_t depth_counter = 0;
|
{
|
||||||
recurse = [&](const array& a) {
|
std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
|
||||||
if (depth_counter > max_graph_depth) {
|
dfs.emplace(synchronizer, 0);
|
||||||
throw std::runtime_error(
|
while (!dfs.empty()) {
|
||||||
"[eval] Graph depth exceeded maximum allowed limit."
|
auto& [a_ref, idx] = dfs.top();
|
||||||
" Try evaluating the graph more frequently.");
|
auto& a = a_ref.get();
|
||||||
}
|
if (idx < a.inputs().size()) {
|
||||||
|
// Add an input, and continue
|
||||||
auto id = a.id();
|
auto& in = a.inputs()[idx++];
|
||||||
if (cache.find(id) != cache.end()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurse to the largest or smallest branch first.
|
|
||||||
depth_counter++;
|
|
||||||
for (auto& in : a.inputs()) {
|
|
||||||
recurse(in);
|
|
||||||
if (!in.is_evaled()) {
|
if (!in.is_evaled()) {
|
||||||
|
if (!in.has_primitive()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[eval] Attempting to eval an array without a primitive.");
|
||||||
|
}
|
||||||
|
|
||||||
// If the input is being computed on a different stream, we need to
|
// If the input is being computed on a different stream, we need to
|
||||||
// manage the dependency.
|
// manage the dependency.
|
||||||
if (a.primitive().stream() != in.primitive().stream()) {
|
if (a.primitive().stream() != in.primitive().stream()) {
|
||||||
deps.insert({in.output(0).id(), std::shared_future<void>{}});
|
deps.insert({in.output(0).id(), std::shared_future<void>{}});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
depth_counter--;
|
|
||||||
|
|
||||||
cache.insert(id);
|
if (cache.find(in.id()) == cache.end()) {
|
||||||
for (auto& s : a.siblings()) {
|
dfs.emplace(in, 0);
|
||||||
|
cache.insert(in.id());
|
||||||
|
for (auto& s : in.siblings()) {
|
||||||
cache.insert(s.id());
|
cache.insert(s.id());
|
||||||
}
|
}
|
||||||
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
|
||||||
if (!a.has_primitive()) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[eval] Attempting to eval an array without a primitive.");
|
|
||||||
}
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// All inputs are done being processed, process this array
|
||||||
|
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
||||||
tape.push(a);
|
tape.push(a);
|
||||||
}
|
}
|
||||||
};
|
dfs.pop();
|
||||||
|
}
|
||||||
recurse(synchronizer);
|
}
|
||||||
deps.insert({synchronizer.id(), std::shared_future<void>{}});
|
deps.insert({synchronizer.id(), std::shared_future<void>{}});
|
||||||
|
|
||||||
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
||||||
|
Loading…
Reference in New Issue
Block a user