Use a dummy primitive to only sync with one output

This commit is contained in:
Awni Hannun 2024-01-13 13:08:19 -08:00
parent 6e81c3e164
commit 9051fa1eaa

View File

@ -18,6 +18,17 @@
namespace mlx::core {
/* Dummy primitive intended solely for use inside eval, where it lets the
 * main thread synchronize with the scheduled work. Evaluating it performs
 * no computation on either backend, and printing it writes nothing. */
class Synchronizer : public Primitive {
 public:
  explicit Synchronizer(Stream stream) : Primitive(stream) {}

  // Both backends are deliberate no-ops: there is nothing to compute.
  void eval_cpu(
      const std::vector<array>& /* inputs */,
      std::vector<array>& /* outputs */) override {}
  void eval_gpu(
      const std::vector<array>& /* inputs */,
      std::vector<array>& /* outputs */) override {}

  // Intentionally emits no text.
  void print(std::ostream& /* os */) override {}
};
// Initialize the static tracing counter from transforms_impl.h.
//
// This is used to implement the in_tracing() function that returns true if we
@ -193,12 +204,11 @@ void eval(const std::vector<array>& outputs) {
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
std::set<std::uintptr_t> output_primitives;
for (auto& arr : outputs) {
if (!arr.is_evaled()) {
output_primitives.insert(arr.primitive_id());
}
}
auto synchronizer = array(
{},
bool_,
std::make_unique<Synchronizer>(default_stream(default_device())),
outputs);
recurse = [&](const array& a) {
auto id = a.id();
@ -206,12 +216,6 @@ void eval(const std::vector<array>& outputs) {
return;
}
for (auto in : a.inputs()) {
// Pop fake outputs from the output set so we know who to synchronize
// with at the end
if (auto it = output_primitives.find(in.primitive_id());
it != output_primitives.end()) {
output_primitives.erase(it);
}
recurse(in);
// If one of the inputs is being computed on a different
// stream, we need to manage the dependency.
@ -234,20 +238,9 @@ void eval(const std::vector<array>& outputs) {
}
};
// We have to store the output primitive ids because the arrays are
// detached during eval and we need to use them for synchronization
// at the end of this function
std::vector<std::uintptr_t> output_primitive_ids;
for (auto& arr : outputs) {
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
recurse(arr);
}
}
// Insert output dependencies
for (auto pid : output_primitives) {
deps.insert({pid, std::shared_future<void>{}});
}
recurse(synchronizer);
uintptr_t synch_id = synchronizer.primitive_id();
deps.insert({synch_id, std::shared_future<void>{}});
std::vector<std::shared_ptr<std::promise<void>>> ps;
while (!tape.empty()) {
@ -305,9 +298,8 @@ void eval(const std::vector<array>& outputs) {
scheduler::enqueue(stream, std::move(task));
}
}
for (auto id : output_primitives) {
deps[id].wait();
}
deps[synch_id].wait();
}
std::pair<std::vector<array>, std::vector<array>> vjp(