mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
parent
0eb56d5be0
commit
8e88e30d95
@ -40,7 +40,7 @@ int detail::InTracing::tracing_counter{0};
|
||||
int detail::RetainGraph::tracing_counter{0};
|
||||
|
||||
array eval_impl(std::vector<array> outputs, bool async) {
|
||||
std::queue<array> tape;
|
||||
std::vector<array> tape;
|
||||
|
||||
// stream events to use for synchronization
|
||||
std::unordered_map<uint32_t, Event> events;
|
||||
@ -64,7 +64,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
events.emplace(stream.index, Event{stream});
|
||||
|
||||
{
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
// Record the degree of each input
|
||||
std::unordered_map<std::uintptr_t, int> cache;
|
||||
|
||||
std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
|
||||
dfs.emplace(synchronizer, 0);
|
||||
while (!dfs.empty()) {
|
||||
@ -104,42 +106,75 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
}
|
||||
|
||||
if (cache.find(in.id()) == cache.end()) {
|
||||
// All siblings have the same degree
|
||||
auto cache_it = cache.find(in.id());
|
||||
if (cache_it == cache.end()) {
|
||||
dfs.emplace(in, 0);
|
||||
cache.insert(in.id());
|
||||
cache.insert({in.id(), 1});
|
||||
for (auto& s : in.siblings()) {
|
||||
cache.insert(s.id());
|
||||
cache.insert({s.id(), 1});
|
||||
}
|
||||
} else {
|
||||
cache_it->second++;
|
||||
for (auto& s : in.siblings()) {
|
||||
cache[s.id()]++;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// All inputs are done being processed, process this array
|
||||
if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
|
||||
a.has_primitive()) {
|
||||
// If the array is evaluated and is no longer a tracer, detach it
|
||||
a.detach();
|
||||
} else if (a.status() == array::Status::unscheduled) {
|
||||
tape.push(a);
|
||||
// Lookup corresponding event and increment counter
|
||||
auto& stream = a.primitive().stream();
|
||||
auto e = events.find(stream.index);
|
||||
if (e == events.end()) {
|
||||
e = events.emplace(stream.index, Event{stream}).first;
|
||||
}
|
||||
e->second.set_value(e->second.value() + 1);
|
||||
a.attach_event(e->second);
|
||||
for (auto& s : a.siblings()) {
|
||||
s.attach_event(e->second);
|
||||
}
|
||||
}
|
||||
dfs.pop();
|
||||
}
|
||||
|
||||
// Build the tape in BFS order
|
||||
tape.push_back(synchronizer);
|
||||
for (int i = 0; !cache.empty() && i < tape.size(); ++i) {
|
||||
auto& a = tape[i];
|
||||
for (auto& in : a.inputs()) {
|
||||
if (in.status() != array::Status::unscheduled) {
|
||||
continue;
|
||||
}
|
||||
auto it = cache.find(in.id());
|
||||
it->second -= 1;
|
||||
|
||||
if (it->second != 0) {
|
||||
for (auto& s : in.siblings()) {
|
||||
cache[s.id()] -= 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Remove input and siblings from cache
|
||||
cache.erase(it);
|
||||
for (auto& s : in.siblings()) {
|
||||
cache.erase(s.id());
|
||||
}
|
||||
|
||||
tape.push_back(in);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while (!tape.empty()) {
|
||||
auto arr = std::move(tape.front());
|
||||
tape.pop();
|
||||
auto arr = std::move(tape.back());
|
||||
tape.pop_back();
|
||||
|
||||
auto stream = arr.primitive().stream();
|
||||
|
||||
// Lookup corresponding event and increment counter
|
||||
auto e = events.find(stream.index);
|
||||
if (e == events.end()) {
|
||||
e = events.emplace(stream.index, Event{stream}).first;
|
||||
}
|
||||
e->second.set_value(e->second.value() + 1);
|
||||
arr.attach_event(e->second);
|
||||
for (auto& s : arr.siblings()) {
|
||||
s.attach_event(e->second);
|
||||
}
|
||||
|
||||
// Set the status of the array and siblings.
|
||||
arr.set_status(array::Status::scheduled);
|
||||
@ -147,7 +182,6 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
s.set_status(array::Status::scheduled);
|
||||
}
|
||||
|
||||
auto stream = arr.primitive().stream();
|
||||
std::vector<std::shared_future<void>> arr_deps;
|
||||
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user