From 5b46b9bc52f07d303169e2df3bfecd3f4a7594f2 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 13 Jan 2024 19:58:20 -0800 Subject: [PATCH] Add priority-queue eval --- mlx/transforms.cpp | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 6135a54f74..0c1368afe9 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -199,10 +199,23 @@ void simplify(const std::vector& outputs) { } void eval(const std::vector& outputs) { - std::function recurse; - std::queue tape; - std::unordered_set cache; + std::function recurse; + std::unordered_map cache; std::unordered_map> deps; + struct ArrayWithPriority { + int depth; + int order; + array x; + }; + auto cmp = [](const ArrayWithPriority& a, const ArrayWithPriority& b) { + return (a.depth > b.depth) || (a.depth == b.depth && a.order > b.order); + }; + std::priority_queue< + ArrayWithPriority, + std::vector, + decltype(cmp)> + tape(cmp); + int order; // Make an effort to choose a good output stream Stream stream = default_stream(default_device()); @@ -218,11 +231,12 @@ void eval(const std::vector& outputs) { recurse = [&](const array& a) { auto id = a.id(); - if (cache.find(id) != cache.end()) { - return; + if (auto it = cache.find(id); it != cache.end()) { + return it->second; } + int input_depth = 0; for (auto in : a.inputs()) { - recurse(in); + input_depth = std::max(input_depth, recurse(in)); // If one of the inputs is being computed on a different // stream, we need to manage the dependency. if (!in.is_evaled()) { @@ -231,17 +245,18 @@ void eval(const std::vector& outputs) { } } } - cache.insert(id); + cache.insert({id, input_depth + 1}); for (auto& s : a.siblings()) { - cache.insert(s.id()); + cache.insert({s.id(), input_depth + 1}); } if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) { if (!a.has_primitive()) { throw std::invalid_argument( "[eval] Attempting to eval an array without a primitive."); } - tape.push(a); + tape.push({input_depth + 1, order++, a}); } + return input_depth + 1; }; recurse(synchronizer); @@ -250,7 +265,8 @@ void eval(const std::vector& outputs) { std::vector>> ps; while (!tape.empty()) { - auto arr = std::move(tape.front()); + auto val = std::move(tape.top()); + auto arr = std::move(val.x); tape.pop(); if (arr.is_evaled()) { if (!arr.is_tracer() && arr.has_primitive()) {