mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
fix
This commit is contained in:
@@ -189,11 +189,14 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_set<int> open_streams;
|
||||
|
||||
while (!tape.empty()) {
|
||||
auto arr = std::move(tape.back());
|
||||
tape.pop_back();
|
||||
|
||||
auto stream = arr.primitive().stream();
|
||||
open_streams.insert(stream.index);
|
||||
|
||||
if (async) {
|
||||
// Lookup corresponding event
|
||||
@@ -234,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
(get_active_memory() > get_memory_limit() &&
|
||||
scheduler::n_active_tasks() > 0)) {
|
||||
// Commit any open streams
|
||||
for (auto& [_, e] : events) {
|
||||
if (e.stream().device == Device::gpu) {
|
||||
gpu::finalize(e.stream());
|
||||
for (auto i : open_streams) {
|
||||
auto s = get_stream(i);
|
||||
if (s.device == Device::gpu) {
|
||||
gpu::finalize(s);
|
||||
}
|
||||
}
|
||||
scheduler::wait_for_one();
|
||||
@@ -270,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
|
||||
// Signal the event in its stream
|
||||
for (auto& [_, e] : events) {
|
||||
auto s = e.stream();
|
||||
e.signal(s);
|
||||
for (auto i : open_streams) {
|
||||
auto s = get_stream(i);
|
||||
if (auto e = events.find(stream.index); e != events.end()) {
|
||||
e->second.signal(s);
|
||||
}
|
||||
if (s.device == Device::gpu) {
|
||||
gpu::finalize(s);
|
||||
}
|
||||
|
Reference in New Issue
Block a user