mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-22 17:41:20 +08:00
parent
0eb56d5be0
commit
8e88e30d95
@ -40,7 +40,7 @@ int detail::InTracing::tracing_counter{0};
|
|||||||
int detail::RetainGraph::tracing_counter{0};
|
int detail::RetainGraph::tracing_counter{0};
|
||||||
|
|
||||||
array eval_impl(std::vector<array> outputs, bool async) {
|
array eval_impl(std::vector<array> outputs, bool async) {
|
||||||
std::queue<array> tape;
|
std::vector<array> tape;
|
||||||
|
|
||||||
// stream events to use for synchronization
|
// stream events to use for synchronization
|
||||||
std::unordered_map<uint32_t, Event> events;
|
std::unordered_map<uint32_t, Event> events;
|
||||||
@ -64,7 +64,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
events.emplace(stream.index, Event{stream});
|
events.emplace(stream.index, Event{stream});
|
||||||
|
|
||||||
{
|
{
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
// Record the degree of each input
|
||||||
|
std::unordered_map<std::uintptr_t, int> cache;
|
||||||
|
|
||||||
std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
|
std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
|
||||||
dfs.emplace(synchronizer, 0);
|
dfs.emplace(synchronizer, 0);
|
||||||
while (!dfs.empty()) {
|
while (!dfs.empty()) {
|
||||||
@ -104,42 +106,75 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cache.find(in.id()) == cache.end()) {
|
// All siblings have the same degree
|
||||||
|
auto cache_it = cache.find(in.id());
|
||||||
|
if (cache_it == cache.end()) {
|
||||||
dfs.emplace(in, 0);
|
dfs.emplace(in, 0);
|
||||||
cache.insert(in.id());
|
cache.insert({in.id(), 1});
|
||||||
for (auto& s : in.siblings()) {
|
for (auto& s : in.siblings()) {
|
||||||
cache.insert(s.id());
|
cache.insert({s.id(), 1});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cache_it->second++;
|
||||||
|
for (auto& s : in.siblings()) {
|
||||||
|
cache[s.id()]++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// All inputs are done being processed, process this array
|
|
||||||
if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
|
if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
|
||||||
a.has_primitive()) {
|
a.has_primitive()) {
|
||||||
// If the array is evaluated and is no longer a tracer, detach it
|
// If the array is evaluated and is no longer a tracer, detach it
|
||||||
a.detach();
|
a.detach();
|
||||||
} else if (a.status() == array::Status::unscheduled) {
|
|
||||||
tape.push(a);
|
|
||||||
// Lookup corresponding event and increment counter
|
|
||||||
auto& stream = a.primitive().stream();
|
|
||||||
auto e = events.find(stream.index);
|
|
||||||
if (e == events.end()) {
|
|
||||||
e = events.emplace(stream.index, Event{stream}).first;
|
|
||||||
}
|
|
||||||
e->second.set_value(e->second.value() + 1);
|
|
||||||
a.attach_event(e->second);
|
|
||||||
for (auto& s : a.siblings()) {
|
|
||||||
s.attach_event(e->second);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
dfs.pop();
|
dfs.pop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build the tape in BFS order
|
||||||
|
tape.push_back(synchronizer);
|
||||||
|
for (int i = 0; !cache.empty() && i < tape.size(); ++i) {
|
||||||
|
auto& a = tape[i];
|
||||||
|
for (auto& in : a.inputs()) {
|
||||||
|
if (in.status() != array::Status::unscheduled) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto it = cache.find(in.id());
|
||||||
|
it->second -= 1;
|
||||||
|
|
||||||
|
if (it->second != 0) {
|
||||||
|
for (auto& s : in.siblings()) {
|
||||||
|
cache[s.id()] -= 1;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove input and siblings from cache
|
||||||
|
cache.erase(it);
|
||||||
|
for (auto& s : in.siblings()) {
|
||||||
|
cache.erase(s.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
tape.push_back(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
while (!tape.empty()) {
|
while (!tape.empty()) {
|
||||||
auto arr = std::move(tape.front());
|
auto arr = std::move(tape.back());
|
||||||
tape.pop();
|
tape.pop_back();
|
||||||
|
|
||||||
|
auto stream = arr.primitive().stream();
|
||||||
|
|
||||||
|
// Lookup corresponding event and increment counter
|
||||||
|
auto e = events.find(stream.index);
|
||||||
|
if (e == events.end()) {
|
||||||
|
e = events.emplace(stream.index, Event{stream}).first;
|
||||||
|
}
|
||||||
|
e->second.set_value(e->second.value() + 1);
|
||||||
|
arr.attach_event(e->second);
|
||||||
|
for (auto& s : arr.siblings()) {
|
||||||
|
s.attach_event(e->second);
|
||||||
|
}
|
||||||
|
|
||||||
// Set the status of the array and siblings.
|
// Set the status of the array and siblings.
|
||||||
arr.set_status(array::Status::scheduled);
|
arr.set_status(array::Status::scheduled);
|
||||||
@ -147,7 +182,6 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
s.set_status(array::Status::scheduled);
|
s.set_status(array::Status::scheduled);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto stream = arr.primitive().stream();
|
|
||||||
std::vector<std::shared_future<void>> arr_deps;
|
std::vector<std::shared_future<void>> arr_deps;
|
||||||
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
|
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user