Remove depth traversal (#813)

* no depth traversal

* counter outside loop
This commit is contained in:
Awni Hannun
2024-03-09 20:21:32 -08:00
committed by GitHub
parent 28301807c2
commit a4d290adb9
3 changed files with 17 additions and 60 deletions

View File

@@ -38,7 +38,7 @@ class Synchronizer : public Primitive {
int detail::InTracing::tracing_counter{0};
void eval(const std::vector<array>& outputs) {
std::function<void(const array&, bool)> recurse;
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
@@ -55,56 +55,32 @@ void eval(const std::vector<array>& outputs) {
auto synchronizer =
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
recurse = [&](const array& a, bool largest_branch_first) {
size_t depth_counter = 0;
recurse = [&](const array& a) {
if (depth_counter > max_graph_depth) {
throw std::runtime_error(
"[eval] Graph depth exceeded maximum allowed limit."
" Try evaluating the graph more frequently.");
}
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
// If the input is being computed on a different stream, we need to manage
// the dependency.
auto check_dependency = [&](const array& in) {
// Recurse to the largest or smallest branch first.
depth_counter++;
for (auto& in : a.inputs()) {
recurse(in);
if (!in.is_evaled()) {
// If the input is being computed on a different stream, we need to
// manage the dependency.
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.primitive_id(), std::shared_future<void>{}});
}
}
};
// Recurse to the largest or smallest branch first.
size_t num_inputs = a.inputs().size();
if (num_inputs == 1) {
auto& in = a.inputs()[0];
recurse(in, true);
check_dependency(in);
} else if (num_inputs == 2) {
auto depth_1 = a.inputs()[0].graph_depth();
auto depth_2 = a.inputs()[1].graph_depth();
auto& in1 = a.inputs()[static_cast<int>(
!((depth_1 > depth_2) == largest_branch_first))];
auto& in2 = a.inputs()[static_cast<int>(
((depth_1 > depth_2) == largest_branch_first))];
recurse(in1, true);
check_dependency(in1);
recurse(in2, true);
check_dependency(in2);
} else if (num_inputs > 2) {
std::vector<int> recursion_order(a.inputs().size());
std::iota(recursion_order.begin(), recursion_order.end(), 0);
std::sort(
recursion_order.begin(),
recursion_order.end(),
[&a, largest_branch_first](int i, int j) {
auto depth_i = a.inputs()[i].graph_depth();
auto depth_j = a.inputs()[j].graph_depth();
return largest_branch_first ? depth_i > depth_j : depth_j < depth_i;
});
for (int idx : recursion_order) {
auto& in = a.inputs()[idx];
recurse(in, true);
check_dependency(in);
}
}
depth_counter--;
cache.insert(id);
for (auto& s : a.siblings()) {
@@ -119,12 +95,7 @@ void eval(const std::vector<array>& outputs) {
}
};
if (synchronizer.graph_depth() > max_graph_depth) {
throw std::runtime_error(
"[eval] Graph depth exceeded maximum allowed limit."
" Try evaluating the graph more frequently.");
}
recurse(synchronizer, false);
recurse(synchronizer);
uintptr_t synch_id = synchronizer.primitive_id();
deps.insert({synch_id, std::shared_future<void>{}});