Fix array is_available race cases (#1468)

This commit is contained in:
Awni Hannun
2024-10-07 19:13:50 -07:00
committed by GitHub
parent 9b12093739
commit 3274c6a087
9 changed files with 77 additions and 23 deletions

View File

@@ -79,7 +79,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
continue;
}
if (!in.is_available()) {
if (in.status() == array::Status::unscheduled) {
if (async && in.is_tracer()) {
throw std::invalid_argument(
"[async_eval] Not allowed inside a graph transformation.");
@@ -115,7 +115,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
// All inputs are done being processed, process this array
if (a.is_available() && !a.is_tracer() && a.has_primitive()) {
if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
a.has_primitive()) {
// If the array is evaluated and is no longer a tracer, detach it
a.detach();
} else if (a.status() == array::Status::unscheduled) {
@@ -208,14 +209,12 @@ void eval(std::vector<array> outputs) {
return x.status() == array::Status::unscheduled;
})) {
for (auto& x : outputs) {
if (!x.is_available()) {
x.event().wait();
}
x.wait();
}
return;
}
eval_impl(std::move(outputs), false).event().wait();
eval_impl(std::move(outputs), false).wait();
}
std::pair<std::vector<array>, std::vector<array>> vjp(