Add priority-queue eval

This commit is contained in:
Angelos Katharopoulos 2024-01-13 19:58:20 -08:00
parent fd94be28ea
commit 5b46b9bc52

View File

@ -199,10 +199,23 @@ void simplify(const std::vector<array>& outputs) {
} }
void eval(const std::vector<array>& outputs) { void eval(const std::vector<array>& outputs) {
std::function<void(const array&)> recurse; std::function<int(const array&)> recurse;
std::queue<array> tape; std::unordered_map<std::uintptr_t, int> cache;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps; std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
struct ArrayWithPriority {
int depth;
int order;
array x;
};
auto cmp = [](const ArrayWithPriority& a, const ArrayWithPriority& b) {
return (a.depth > b.depth) || (a.depth == b.depth && a.order > b.order);
};
std::priority_queue<
ArrayWithPriority,
std::vector<ArrayWithPriority>,
decltype(cmp)>
tape(cmp);
int order;
// Make an effort to choose a good output stream // Make an effort to choose a good output stream
Stream stream = default_stream(default_device()); Stream stream = default_stream(default_device());
@ -218,11 +231,12 @@ void eval(const std::vector<array>& outputs) {
recurse = [&](const array& a) { recurse = [&](const array& a) {
auto id = a.id(); auto id = a.id();
if (cache.find(id) != cache.end()) { if (auto it = cache.find(id); it != cache.end()) {
return; return it->second;
} }
int input_depth = 0;
for (auto in : a.inputs()) { for (auto in : a.inputs()) {
recurse(in); input_depth = std::max(input_depth, recurse(in));
// If one of the inputs is being computed on a different // If one of the inputs is being computed on a different
// stream, we need to manage the dependency. // stream, we need to manage the dependency.
if (!in.is_evaled()) { if (!in.is_evaled()) {
@ -231,17 +245,18 @@ void eval(const std::vector<array>& outputs) {
} }
} }
} }
cache.insert(id); cache.insert({id, input_depth + 1});
for (auto& s : a.siblings()) { for (auto& s : a.siblings()) {
cache.insert(s.id()); cache.insert({s.id(), input_depth + 1});
} }
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) { if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
if (!a.has_primitive()) { if (!a.has_primitive()) {
throw std::invalid_argument( throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive."); "[eval] Attempting to eval an array without a primitive.");
} }
tape.push(a); tape.push({input_depth + 1, order++, a});
} }
return input_depth + 1;
}; };
recurse(synchronizer); recurse(synchronizer);
@ -250,7 +265,8 @@ void eval(const std::vector<array>& outputs) {
std::vector<std::shared_ptr<std::promise<void>>> ps; std::vector<std::shared_ptr<std::promise<void>>> ps;
while (!tape.empty()) { while (!tape.empty()) {
auto arr = std::move(tape.front()); auto val = std::move(tape.top());
auto arr = std::move(val.x);
tape.pop(); tape.pop();
if (arr.is_evaled()) { if (arr.is_evaled()) {
if (!arr.is_tracer() && arr.has_primitive()) { if (!arr.is_tracer() && arr.has_primitive()) {