diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 35a5322fe..265ef5cdd 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -189,11 +189,14 @@ array eval_impl(std::vector outputs, bool async) { } } + std::unordered_set open_streams; + while (!tape.empty()) { auto arr = std::move(tape.back()); tape.pop_back(); auto stream = arr.primitive().stream(); + open_streams.insert(stream.index); if (async) { // Lookup corresponding event @@ -234,9 +237,10 @@ array eval_impl(std::vector outputs, bool async) { (get_active_memory() > get_memory_limit() && scheduler::n_active_tasks() > 0)) { // Commit any open streams - for (auto& [_, e] : events) { - if (e.stream().device == Device::gpu) { - gpu::finalize(e.stream()); + for (auto i : open_streams) { + auto s = get_stream(i); + if (s.device == Device::gpu) { + gpu::finalize(s); } } scheduler::wait_for_one(); @@ -270,9 +274,11 @@ array eval_impl(std::vector outputs, bool async) { } // Signal the event in its stream - for (auto& [_, e] : events) { - auto s = e.stream(); - e.signal(s); + for (auto i : open_streams) { + auto s = get_stream(i); + if (auto e = events.find(stream.index); e != events.end()) { + e->second.signal(s); + } if (s.device == Device::gpu) { gpu::finalize(s); }