Mirror of https://github.com/ml-explore/mlx.git (last synced 2025-07-13 19:51:13 +08:00)
Sync only with outputs we need to sync with (#447)
This commit is contained in:
parent
2e29d0815b
commit
6e81c3e164
@ -193,12 +193,25 @@ void eval(const std::vector<array>& outputs) {
|
|||||||
std::unordered_set<std::uintptr_t> cache;
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
||||||
|
|
||||||
|
std::set<std::uintptr_t> output_primitives;
|
||||||
|
for (auto& arr : outputs) {
|
||||||
|
if (!arr.is_evaled()) {
|
||||||
|
output_primitives.insert(arr.primitive_id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
recurse = [&](const array& a) {
|
recurse = [&](const array& a) {
|
||||||
auto id = a.id();
|
auto id = a.id();
|
||||||
if (cache.find(id) != cache.end()) {
|
if (cache.find(id) != cache.end()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (auto in : a.inputs()) {
|
for (auto in : a.inputs()) {
|
||||||
|
// Pop fake outputs from the output set so we know who to synchronize
|
||||||
|
// with at the end
|
||||||
|
if (auto it = output_primitives.find(in.primitive_id());
|
||||||
|
it != output_primitives.end()) {
|
||||||
|
output_primitives.erase(it);
|
||||||
|
}
|
||||||
recurse(in);
|
recurse(in);
|
||||||
// If one of the inputs is being computed on a different
|
// If one of the inputs is being computed on a different
|
||||||
// stream, we need to manage the dependency.
|
// stream, we need to manage the dependency.
|
||||||
@ -228,13 +241,12 @@ void eval(const std::vector<array>& outputs) {
|
|||||||
for (auto& arr : outputs) {
|
for (auto& arr : outputs) {
|
||||||
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
|
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
|
||||||
recurse(arr);
|
recurse(arr);
|
||||||
// Insert a dependency for every output to synchronize
|
|
||||||
// with at the end.
|
|
||||||
if (!arr.is_evaled() && deps.find(arr.primitive_id()) == deps.end()) {
|
|
||||||
deps.insert({arr.primitive_id(), std::shared_future<void>{}});
|
|
||||||
output_primitive_ids.push_back(arr.primitive_id());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Insert output dependencies
|
||||||
|
for (auto pid : output_primitives) {
|
||||||
|
deps.insert({pid, std::shared_future<void>{}});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
||||||
@ -293,7 +305,7 @@ void eval(const std::vector<array>& outputs) {
|
|||||||
scheduler::enqueue(stream, std::move(task));
|
scheduler::enqueue(stream, std::move(task));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto id : output_primitive_ids) {
|
for (auto id : output_primitives) {
|
||||||
deps[id].wait();
|
deps[id].wait();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user