Async eval (#972)

This commit is contained in:
Awni Hannun
2024-04-09 18:34:00 -07:00
committed by GitHub
parent fffe072028
commit 99abb9eff4
5 changed files with 70 additions and 3 deletions

View File

@@ -38,7 +38,13 @@ class Synchronizer : public Primitive {
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
void eval(std::vector<array> outputs) {
std::shared_future<void> async_eval(std::vector<array> outputs) {
static std::shared_future<void> global_synchronizer;
// Catch up with previous async eval if needed
if (global_synchronizer.valid()) {
global_synchronizer.wait();
}
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
@@ -152,8 +158,12 @@ void eval(std::vector<array> outputs) {
scheduler::enqueue(stream, std::move(task));
}
}
global_synchronizer = std::move(deps[synchronizer.id()]);
return global_synchronizer;
}
deps[synchronizer.id()].wait();
void eval(std::vector<array> outputs) {
async_eval(std::move(outputs)).wait();
}
std::pair<std::vector<array>, std::vector<array>> vjp(