Fix race in multi-stream eval (#911)

* maybe fix race

* comment
This commit is contained in:
Awni Hannun 2024-03-26 16:36:36 -07:00 committed by GitHub
parent 29221fa238
commit dc175f08d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 7 deletions

View File

@ -256,6 +256,17 @@ class array {
array_desc_->position = position;
}
/**
 * The i-th output of the array's primitive.
 *
 * This array sits at index `array_desc_->position` among its primitive's
 * outputs; the remaining outputs are held, in order, by `siblings()`, which
 * does not contain this array itself.
 *
 * No bounds checking is performed on `i`.
 *
 * @param i Index of the requested output in the primitive's output order.
 * @return A reference to this array when `i` equals its own position,
 *         otherwise a reference to the corresponding sibling.
 */
const array& output(int i) const {
  if (i == array_desc_->position) {
    return *this;
  } else if (i < array_desc_->position) {
    // Outputs before this array keep their index in the sibling list.
    return siblings()[i];
  } else {
    // This array is skipped in siblings(), so every output after it is
    // stored one slot lower than its output index.
    return siblings()[i - 1];
  }
}
/** The outputs of the array's primitive (i.e. this array and
* its siblings) in the order the primitive expects. */
std::vector<array> outputs() const {

View File

@ -76,7 +76,7 @@ void eval(std::vector<array> outputs) {
// If the input is being computed on a different stream, we need to
// manage the dependency.
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.primitive_id(), std::shared_future<void>{}});
deps.insert({in.output(0).id(), std::shared_future<void>{}});
}
}
}
@ -96,8 +96,7 @@ void eval(std::vector<array> outputs) {
};
recurse(synchronizer);
uintptr_t synch_id = synchronizer.primitive_id();
deps.insert({synch_id, std::shared_future<void>{}});
deps.insert({synchronizer.id(), std::shared_future<void>{}});
std::vector<std::shared_ptr<std::promise<void>>> ps;
while (!tape.empty()) {
@ -113,13 +112,12 @@ void eval(std::vector<array> outputs) {
auto stream = arr.primitive().stream();
std::vector<std::shared_future<void>> arr_deps;
for (auto& in : arr.inputs()) {
// TODO that's a bug
if (auto it = deps.find(in.primitive_id()); it != deps.end()) {
if (auto it = deps.find(in.output(0).id()); it != deps.end()) {
arr_deps.push_back(it->second);
}
}
std::shared_ptr<std::promise<void>> p;
if (auto it = deps.find(arr.primitive_id()); it != deps.end()) {
if (auto it = deps.find(arr.output(0).id()); it != deps.end()) {
p = std::make_unique<std::promise<void>>();
ps.push_back(p);
it->second = p->get_future().share();
@ -154,7 +152,7 @@ void eval(std::vector<array> outputs) {
}
}
deps[synch_id].wait();
deps[synchronizer.id()].wait();
}
std::pair<std::vector<array>, std::vector<array>> vjp(