diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 8ca7c8f53..ed0c082e9 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -18,6 +18,17 @@ namespace mlx::core { +/* This class is only meant to be used in eval + * for synchronizing with the main thread. */ +class Synchronizer : public Primitive { + public: + explicit Synchronizer(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector&, std::vector&) override{}; + void eval_gpu(const std::vector&, std::vector&) override{}; + void print(std::ostream&) override {} +}; + // Initialize the static tracing counter from transforms_impl.h . // // This is used to implement the in_tracing() function the returns true if we @@ -193,12 +204,11 @@ void eval(const std::vector& outputs) { std::unordered_set cache; std::unordered_map> deps; - std::set output_primitives; - for (auto& arr : outputs) { - if (!arr.is_evaled()) { - output_primitives.insert(arr.primitive_id()); - } - } + auto synchronizer = array( + {}, + bool_, + std::make_unique(default_stream(default_device())), + outputs); recurse = [&](const array& a) { auto id = a.id(); @@ -206,12 +216,6 @@ void eval(const std::vector& outputs) { return; } for (auto in : a.inputs()) { - // Pop fake outputs from the output set so we know who to synchronize - // with at the end - if (auto it = output_primitives.find(in.primitive_id()); - it != output_primitives.end()) { - output_primitives.erase(it); - } recurse(in); // If one of the inputs is being computed on a different // stream, we need to manage the dependency. @@ -234,20 +238,9 @@ void eval(const std::vector& outputs) { } }; - // We have to store the output primitive ids because the arrays are - // detached during eval and we need to use them for synchronization - // at the end of this function - std::vector output_primitive_ids; - for (auto& arr : outputs) { - if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) { - recurse(arr); - } - } - - // Insert output dependencies - for (auto pid : output_primitives) { - deps.insert({pid, std::shared_future{}}); - } + recurse(synchronizer); + uintptr_t synch_id = synchronizer.primitive_id(); + deps.insert({synch_id, std::shared_future{}}); std::vector>> ps; while (!tape.empty()) { @@ -305,9 +298,8 @@ void eval(const std::vector& outputs) { scheduler::enqueue(stream, std::move(task)); } } - for (auto id : output_primitives) { - deps[id].wait(); - } + + deps[synch_id].wait(); } std::pair, std::vector> vjp(