This commit is contained in:
Awni Hannun
2025-07-23 16:36:09 -07:00
parent 5a4f375c6c
commit b2280a1c41

View File

@@ -189,11 +189,14 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
}
std::unordered_set<int> open_streams;
while (!tape.empty()) {
auto arr = std::move(tape.back());
tape.pop_back();
auto stream = arr.primitive().stream();
open_streams.insert(stream.index);
if (async) {
// Lookup corresponding event
@@ -234,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
(get_active_memory() > get_memory_limit() &&
scheduler::n_active_tasks() > 0)) {
// Commit any open streams
for (auto& [_, e] : events) {
if (e.stream().device == Device::gpu) {
gpu::finalize(e.stream());
for (auto i : open_streams) {
auto s = get_stream(i);
if (s.device == Device::gpu) {
gpu::finalize(s);
}
}
scheduler::wait_for_one();
@@ -270,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
// Signal the event in its stream
for (auto& [_, e] : events) {
auto s = e.stream();
e.signal(s);
for (auto i : open_streams) {
auto s = get_stream(i);
if (auto e = events.find(stream.index); e != events.end()) {
e->second.signal(s);
}
if (s.device == Device::gpu) {
gpu::finalize(s);
}