mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix array is_available race cases (#1468)
This commit is contained in:
@@ -79,7 +79,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!in.is_available()) {
|
||||
if (in.status() == array::Status::unscheduled) {
|
||||
if (async && in.is_tracer()) {
|
||||
throw std::invalid_argument(
|
||||
"[async_eval] Not allowed inside a graph transformation.");
|
||||
@@ -115,7 +115,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
|
||||
// All inputs are done being processed, process this array
|
||||
if (a.is_available() && !a.is_tracer() && a.has_primitive()) {
|
||||
if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
|
||||
a.has_primitive()) {
|
||||
// If the array is evaluated and is no longer a tracer, detach it
|
||||
a.detach();
|
||||
} else if (a.status() == array::Status::unscheduled) {
|
||||
@@ -208,14 +209,12 @@ void eval(std::vector<array> outputs) {
|
||||
return x.status() == array::Status::unscheduled;
|
||||
})) {
|
||||
for (auto& x : outputs) {
|
||||
if (!x.is_available()) {
|
||||
x.event().wait();
|
||||
}
|
||||
x.wait();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
eval_impl(std::move(outputs), false).event().wait();
|
||||
eval_impl(std::move(outputs), false).wait();
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
|
||||
Reference in New Issue
Block a user