mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-12 12:16:43 +08:00
Add priority-queue eval
This commit is contained in:
parent
fd94be28ea
commit
5b46b9bc52
@ -199,10 +199,23 @@ void simplify(const std::vector<array>& outputs) {
|
||||
}
|
||||
|
||||
void eval(const std::vector<array>& outputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
std::function<int(const array&)> recurse;
|
||||
std::unordered_map<std::uintptr_t, int> cache;
|
||||
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
||||
struct ArrayWithPriority {
|
||||
int depth;
|
||||
int order;
|
||||
array x;
|
||||
};
|
||||
auto cmp = [](const ArrayWithPriority& a, const ArrayWithPriority& b) {
|
||||
return (a.depth > b.depth) || (a.depth == b.depth && a.order > b.order);
|
||||
};
|
||||
std::priority_queue<
|
||||
ArrayWithPriority,
|
||||
std::vector<ArrayWithPriority>,
|
||||
decltype(cmp)>
|
||||
tape(cmp);
|
||||
int order;
|
||||
|
||||
// Make an effort to choose a good output stream
|
||||
Stream stream = default_stream(default_device());
|
||||
@ -218,11 +231,12 @@ void eval(const std::vector<array>& outputs) {
|
||||
|
||||
recurse = [&](const array& a) {
|
||||
auto id = a.id();
|
||||
if (cache.find(id) != cache.end()) {
|
||||
return;
|
||||
if (auto it = cache.find(id); it != cache.end()) {
|
||||
return it->second;
|
||||
}
|
||||
int input_depth = 0;
|
||||
for (auto in : a.inputs()) {
|
||||
recurse(in);
|
||||
input_depth = std::max(input_depth, recurse(in));
|
||||
// If one of the inputs is being computed on a different
|
||||
// stream, we need to manage the dependency.
|
||||
if (!in.is_evaled()) {
|
||||
@ -231,17 +245,18 @@ void eval(const std::vector<array>& outputs) {
|
||||
}
|
||||
}
|
||||
}
|
||||
cache.insert(id);
|
||||
cache.insert({id, input_depth + 1});
|
||||
for (auto& s : a.siblings()) {
|
||||
cache.insert(s.id());
|
||||
cache.insert({s.id(), input_depth + 1});
|
||||
}
|
||||
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
||||
if (!a.has_primitive()) {
|
||||
throw std::invalid_argument(
|
||||
"[eval] Attempting to eval an array without a primitive.");
|
||||
}
|
||||
tape.push(a);
|
||||
tape.push({input_depth + 1, order++, a});
|
||||
}
|
||||
return input_depth + 1;
|
||||
};
|
||||
|
||||
recurse(synchronizer);
|
||||
@ -250,7 +265,8 @@ void eval(const std::vector<array>& outputs) {
|
||||
|
||||
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
||||
while (!tape.empty()) {
|
||||
auto arr = std::move(tape.front());
|
||||
auto val = std::move(tape.top());
|
||||
auto arr = std::move(val.x);
|
||||
tape.pop();
|
||||
if (arr.is_evaled()) {
|
||||
if (!arr.is_tracer() && arr.has_primitive()) {
|
||||
|
Loading…
Reference in New Issue
Block a user