diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 8ca7c8f53..6135a54f7 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -18,6 +18,17 @@ namespace mlx::core { +/* This class is only meant to be used in eval + * for synchronizing with the main thread. */ +class Synchronizer : public Primitive { + public: + explicit Synchronizer(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector&, std::vector&) override{}; + void eval_gpu(const std::vector&, std::vector&) override{}; + void print(std::ostream&) override {} +}; + // Initialize the static tracing counter from transforms_impl.h . // // This is used to implement the in_tracing() function the returns true if we @@ -193,25 +204,24 @@ void eval(const std::vector& outputs) { std::unordered_set cache; std::unordered_map> deps; - std::set output_primitives; - for (auto& arr : outputs) { - if (!arr.is_evaled()) { - output_primitives.insert(arr.primitive_id()); + // Make an effort to choose a good output stream + Stream stream = default_stream(default_device()); + for (auto& o : outputs) { + if (!o.is_evaled() && o.has_primitive()) { + stream = o.primitive().stream(); + break; } } + auto synchronizer = + array({}, bool_, std::make_unique(stream), outputs); + recurse = [&](const array& a) { auto id = a.id(); if (cache.find(id) != cache.end()) { return; } for (auto in : a.inputs()) { - // Pop fake outputs from the output set so we know who to synchronize - // with at the end - if (auto it = output_primitives.find(in.primitive_id()); - it != output_primitives.end()) { - output_primitives.erase(it); - } recurse(in); // If one of the inputs is being computed on a different // stream, we need to manage the dependency. @@ -234,20 +244,9 @@ void eval(const std::vector& outputs) { } }; - // We have to store the output primitive ids because the arrays are - // detached during eval and we need to use them for synchronization - // at the end of this function - std::vector output_primitive_ids; - for (auto& arr : outputs) { - if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) { - recurse(arr); - } - } - - // Insert output dependencies - for (auto pid : output_primitives) { - deps.insert({pid, std::shared_future{}}); - } + recurse(synchronizer); + uintptr_t synch_id = synchronizer.primitive_id(); + deps.insert({synch_id, std::shared_future{}}); std::vector>> ps; while (!tape.empty()) { @@ -305,9 +304,8 @@ void eval(const std::vector& outputs) { scheduler::enqueue(stream, std::move(task)); } } - for (auto id : output_primitives) { - deps[id].wait(); - } + + deps[synch_id].wait(); } std::pair, std::vector> vjp(