diff --git a/mlx/array.h b/mlx/array.h index 266353196..b142f90c8 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -256,6 +256,17 @@ class array { array_desc_->position = position; } + /** The i-th output of the array's primitive. */ + const array& output(int i) const { + if (i == array_desc_->position) { + return *this; + } else if (i < array_desc_->position) { + return siblings()[i]; + } else { + return siblings()[i + 1]; + } + }; + /** The outputs of the array's primitive (i.e. this array and * its siblings) in the order the primitive expects. */ std::vector outputs() const { diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 107181f78..42c9213f6 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -76,7 +76,7 @@ void eval(std::vector outputs) { // If the input is being computed on a different stream, we need to // manage the dependency. if (a.primitive().stream() != in.primitive().stream()) { - deps.insert({in.primitive_id(), std::shared_future{}}); + deps.insert({in.output(0).id(), std::shared_future{}}); } } } @@ -96,8 +96,7 @@ void eval(std::vector outputs) { }; recurse(synchronizer); - uintptr_t synch_id = synchronizer.primitive_id(); - deps.insert({synch_id, std::shared_future{}}); + deps.insert({synchronizer.id(), std::shared_future{}}); std::vector>> ps; while (!tape.empty()) { @@ -113,13 +112,12 @@ void eval(std::vector outputs) { auto stream = arr.primitive().stream(); std::vector> arr_deps; for (auto& in : arr.inputs()) { - // TODO that's a bug - if (auto it = deps.find(in.primitive_id()); it != deps.end()) { + if (auto it = deps.find(in.output(0).id()); it != deps.end()) { arr_deps.push_back(it->second); } } std::shared_ptr> p; - if (auto it = deps.find(arr.primitive_id()); it != deps.end()) { + if (auto it = deps.find(arr.output(0).id()); it != deps.end()) { p = std::make_unique>(); ps.push_back(p); it->second = p->get_future().share(); @@ -154,7 +152,7 @@ void eval(std::vector outputs) { } } - deps[synch_id].wait(); + deps[synchronizer.id()].wait(); } std::pair, std::vector> vjp(