diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index bb39492a8b..a19bb8c30e 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -17,9 +18,6 @@ namespace mlx::core { -// Maximum allowed graph depth for eval -constexpr uint32_t max_graph_depth = 100'000; - /* This class is only meant to be used in eval * for synchronizing with the main thread. */ class Synchronizer : public Primitive { @@ -44,8 +42,6 @@ std::shared_future async_eval(std::vector outputs) { if (global_synchronizer.valid()) { global_synchronizer.wait(); } - - std::function recurse; std::queue tape; std::unordered_set cache; std::unordered_map> deps; @@ -62,47 +58,45 @@ std::shared_future async_eval(std::vector outputs) { auto synchronizer = array( {}, bool_, std::make_shared(stream), std::move(outputs)); - size_t depth_counter = 0; - recurse = [&](const array& a) { - if (depth_counter > max_graph_depth) { - throw std::runtime_error( - "[eval] Graph depth exceeded maximum allowed limit." - " Try evaluating the graph more frequently."); - } + { + std::stack, int>> dfs; + dfs.emplace(synchronizer, 0); + while (!dfs.empty()) { + auto& [a_ref, idx] = dfs.top(); + auto& a = a_ref.get(); + if (idx < a.inputs().size()) { + // Add an input, and continue + auto& in = a.inputs()[idx++]; + if (!in.is_evaled()) { + if (!in.has_primitive()) { + throw std::invalid_argument( + "[eval] Attempting to eval an array without a primitive."); + } - auto id = a.id(); - if (cache.find(id) != cache.end()) { - return; - } - - // Recurse to the largest or smallest branch first. - depth_counter++; - for (auto& in : a.inputs()) { - recurse(in); - if (!in.is_evaled()) { - // If the input is being computed on a different stream, we need to - // manage the dependency. - if (a.primitive().stream() != in.primitive().stream()) { - deps.insert({in.output(0).id(), std::shared_future{}}); + // If the input is being computed on a different stream, we need to + // manage the dependency. + if (a.primitive().stream() != in.primitive().stream()) { + deps.insert({in.output(0).id(), std::shared_future{}}); + } } - } - } - depth_counter--; - cache.insert(id); - for (auto& s : a.siblings()) { - cache.insert(s.id()); - } - if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) { - if (!a.has_primitive()) { - throw std::invalid_argument( - "[eval] Attempting to eval an array without a primitive."); + if (cache.find(in.id()) == cache.end()) { + dfs.emplace(in, 0); + cache.insert(in.id()); + for (auto& s : in.siblings()) { + cache.insert(s.id()); + } + } + continue; } - tape.push(a); - } - }; - recurse(synchronizer); + // All inputs are done being processed, process this array + if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) { + tape.push(a); + } + dfs.pop(); + } + } deps.insert({synchronizer.id(), std::shared_future{}}); std::vector>> ps;