From 6e81c3e164286b6852ce28deb2e42bdbbf7fb447 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 13 Jan 2024 01:47:25 -0800 Subject: [PATCH] Sync only with outputs we need to sync with (#447) --- mlx/transforms.cpp | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index bbd9dc5e1..8ca7c8f53 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -193,12 +193,25 @@ void eval(const std::vector& outputs) { std::unordered_set cache; std::unordered_map> deps; + std::set output_primitives; + for (auto& arr : outputs) { + if (!arr.is_evaled()) { + output_primitives.insert(arr.primitive_id()); + } + } + recurse = [&](const array& a) { auto id = a.id(); if (cache.find(id) != cache.end()) { return; } for (auto in : a.inputs()) { + // Pop fake outputs from the output set so we know who to synchronize + // with at the end + if (auto it = output_primitives.find(in.primitive_id()); + it != output_primitives.end()) { + output_primitives.erase(it); + } recurse(in); // If one of the inputs is being computed on a different // stream, we need to manage the dependency. @@ -228,15 +241,14 @@ void eval(const std::vector& outputs) { for (auto& arr : outputs) { if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) { recurse(arr); - // Insert a dependency for every output to synchronize - // with at the end. - if (!arr.is_evaled() && deps.find(arr.primitive_id()) == deps.end()) { - deps.insert({arr.primitive_id(), std::shared_future{}}); - output_primitive_ids.push_back(arr.primitive_id()); - } } } + // Insert output dependencies + for (auto pid : output_primitives) { + deps.insert({pid, std::shared_future{}}); + } + std::vector>> ps; while (!tape.empty()) { auto arr = std::move(tape.front()); @@ -293,7 +305,7 @@ void eval(const std::vector& outputs) { scheduler::enqueue(stream, std::move(task)); } } - for (auto id : output_primitive_ids) { + for (auto id : output_primitives) { deps[id].wait(); } }