diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index eb7f1b58a..39a49230d 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -14,6 +14,10 @@ Event::Event(Stream stream) : stream_(stream) { auto p = metal::new_scoped_memory_pool(); event_ = std::shared_ptr( metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor); + if (event_ == nullptr) { + throw std::runtime_error( + "[Event::Event] Failed to create Metal shared event."); + } } void Event::wait() { diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index d9e227ea3..bcf0cc09f 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -72,7 +72,12 @@ array eval_impl(std::vector outputs, bool async) { // Stream events for synchronization after eval std::unordered_map events; - events.emplace(stream.index, Event{stream}); + { + auto e = Event{stream}; + e.set_value(1); + synchronizer.attach_event(e); + events.emplace(stream.index, std::move(e)); + } { // Record the degree of each input @@ -184,21 +189,26 @@ array eval_impl(std::vector outputs, bool async) { } } + std::unordered_set open_streams; + while (!tape.empty()) { auto arr = std::move(tape.back()); tape.pop_back(); auto stream = arr.primitive().stream(); + open_streams.insert(stream.index); - // Lookup corresponding event - auto e = events.find(stream.index); - if (e == events.end()) { - e = events.emplace(stream.index, Event{stream}).first; - } - e->second.set_value(1); - arr.attach_event(e->second); - for (auto& s : arr.siblings()) { - s.attach_event(e->second); + if (async) { + // Lookup corresponding event + auto e = events.find(stream.index); + if (e == events.end()) { + e = events.emplace(stream.index, Event{stream}).first; + } + e->second.set_value(1); + arr.attach_event(e->second); + for (auto& s : arr.siblings()) { + s.attach_event(e->second); + } } for (auto& in : arr.inputs()) { @@ -227,9 +237,10 @@ array eval_impl(std::vector outputs, bool async) { (get_active_memory() > get_memory_limit() && scheduler::n_active_tasks() > 0)) { // Commit any open streams - for (auto& [_, e] : events) { - if (e.stream().device == Device::gpu) { - gpu::finalize(e.stream()); + for (auto i : open_streams) { + auto s = get_stream(i); + if (s.device == Device::gpu) { + gpu::finalize(s); } } scheduler::wait_for_one(); @@ -263,9 +274,11 @@ array eval_impl(std::vector outputs, bool async) { } // Signal the event in its stream - for (auto& [_, e] : events) { - auto s = e.stream(); - e.signal(s); + for (auto i : open_streams) { + auto s = get_stream(i); + if (auto e = events.find(i); e != events.end()) { + e->second.signal(s); + } if (s.device == Device::gpu) { gpu::finalize(s); } @@ -302,7 +315,7 @@ void eval(std::vector outputs) { return; } - eval_impl(std::move(outputs), false).event().wait(); + eval_impl(std::move(outputs), false).wait(); } std::pair, std::vector> vjp(