Sync only with outputs we need to sync with (#447)

This commit is contained in:
Awni Hannun 2024-01-13 01:47:25 -08:00 committed by GitHub
parent 2e29d0815b
commit 6e81c3e164
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -193,12 +193,25 @@ void eval(const std::vector<array>& outputs) {
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
std::set<std::uintptr_t> output_primitives;
for (auto& arr : outputs) {
if (!arr.is_evaled()) {
output_primitives.insert(arr.primitive_id());
}
}
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (auto in : a.inputs()) {
// Pop fake outputs from the output set so we know who to synchronize
// with at the end
if (auto it = output_primitives.find(in.primitive_id());
it != output_primitives.end()) {
output_primitives.erase(it);
}
recurse(in);
// If one of the inputs is being computed on a different
// stream, we need to manage the dependency.
@@ -228,15 +241,14 @@ void eval(const std::vector<array>& outputs) {
for (auto& arr : outputs) {
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
recurse(arr);
// Insert a dependency for every output to synchronize
// with at the end.
if (!arr.is_evaled() && deps.find(arr.primitive_id()) == deps.end()) {
deps.insert({arr.primitive_id(), std::shared_future<void>{}});
output_primitive_ids.push_back(arr.primitive_id());
}
}
}
// Insert output dependencies
for (auto pid : output_primitives) {
deps.insert({pid, std::shared_future<void>{}});
}
std::vector<std::shared_ptr<std::promise<void>>> ps;
while (!tape.empty()) {
auto arr = std::move(tape.front());
@@ -293,7 +305,7 @@ void eval(const std::vector<array>& outputs) {
scheduler::enqueue(stream, std::move(task));
}
}
for (auto id : output_primitive_ids) {
for (auto id : output_primitives) {
deps[id].wait();
}
}