diff --git a/mlx/array.cpp b/mlx/array.cpp index 83c2fe6d7..add6be279 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -104,13 +104,11 @@ void array::detach() { s.array_desc_->inputs.clear(); s.array_desc_->siblings.clear(); s.array_desc_->position = 0; - s.array_desc_->depth = 0; s.array_desc_->primitive = nullptr; } array_desc_->inputs.clear(); array_desc_->siblings.clear(); array_desc_->position = 0; - array_desc_->depth = 0; array_desc_->primitive = nullptr; } @@ -189,9 +187,7 @@ array::ArrayDesc::ArrayDesc( std::tie(size, strides) = cum_prod(this->shape); for (auto& in : this->inputs) { is_tracer |= in.is_tracer(); - depth = std::max(in.graph_depth(), depth); } - depth++; } array::ArrayDesc::ArrayDesc( @@ -206,9 +202,7 @@ array::ArrayDesc::ArrayDesc( std::tie(size, strides) = cum_prod(this->shape); for (auto& in : this->inputs) { is_tracer |= in.is_tracer(); - depth = std::max(in.graph_depth(), depth); } - depth++; } array::ArrayIterator::ArrayIterator(const array& arr, int idx) diff --git a/mlx/array.h b/mlx/array.h index f73dd66ff..740e69886 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -273,11 +273,6 @@ class array { return outputs; }; - /** The depth of the array in the graph. Evaluated arrays have depth 0. */ - uint32_t graph_depth() const { - return array_desc_->depth; - } - /** Detach the array from the graph. */ void detach(); @@ -388,9 +383,6 @@ class array { // The arrays position in the output list uint32_t position{0}; - // The depth of the array in the graph. 
- uint32_t depth{0}; - explicit ArrayDesc(const std::vector& shape, Dtype dtype); explicit ArrayDesc( diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 1ba403ea1..e89d50375 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -38,7 +38,7 @@ class Synchronizer : public Primitive { int detail::InTracing::tracing_counter{0}; void eval(const std::vector& outputs) { - std::function recurse; + std::function recurse; std::queue tape; std::unordered_set cache; std::unordered_map> deps; @@ -55,56 +55,32 @@ void eval(const std::vector& outputs) { auto synchronizer = array({}, bool_, std::make_unique(stream), outputs); - recurse = [&](const array& a, bool largest_branch_first) { + size_t depth_counter = 0; + recurse = [&](const array& a) { + if (depth_counter > max_graph_depth) { + throw std::runtime_error( + "[eval] Graph depth exceeded maximum allowed limit." + " Try evaluating the graph more frequently."); + } + auto id = a.id(); if (cache.find(id) != cache.end()) { return; } - // If the input is being computed on a different stream, we need to manage - // the dependency. - auto check_dependency = [&](const array& in) { + // Recurse into each input in order, tracking the current recursion depth. + depth_counter++; + for (auto& in : a.inputs()) { + recurse(in); if (!in.is_evaled()) { + // If the input is being computed on a different stream, we need to + // manage the dependency. if (a.primitive().stream() != in.primitive().stream()) { deps.insert({in.primitive_id(), std::shared_future{}}); } } - }; - - // Recurse to the largest or smallest branch first. 
- size_t num_inputs = a.inputs().size(); - if (num_inputs == 1) { - auto& in = a.inputs()[0]; - recurse(in, true); - check_dependency(in); - } else if (num_inputs == 2) { - auto depth_1 = a.inputs()[0].graph_depth(); - auto depth_2 = a.inputs()[1].graph_depth(); - auto& in1 = a.inputs()[static_cast( - !((depth_1 > depth_2) == largest_branch_first))]; - auto& in2 = a.inputs()[static_cast( - ((depth_1 > depth_2) == largest_branch_first))]; - recurse(in1, true); - check_dependency(in1); - recurse(in2, true); - check_dependency(in2); - } else if (num_inputs > 2) { - std::vector recursion_order(a.inputs().size()); - std::iota(recursion_order.begin(), recursion_order.end(), 0); - std::sort( - recursion_order.begin(), - recursion_order.end(), - [&a, largest_branch_first](int i, int j) { - auto depth_i = a.inputs()[i].graph_depth(); - auto depth_j = a.inputs()[j].graph_depth(); - return largest_branch_first ? depth_i > depth_j : depth_j < depth_i; - }); - for (int idx : recursion_order) { - auto& in = a.inputs()[idx]; - recurse(in, true); - check_dependency(in); - } } + depth_counter--; cache.insert(id); for (auto& s : a.siblings()) { @@ -119,12 +95,7 @@ void eval(const std::vector& outputs) { } }; - if (synchronizer.graph_depth() > max_graph_depth) { - throw std::runtime_error( - "[eval] Graph depth exceeded maximum allowed limit." - " Try evaluating the graph more frequently."); - } - recurse(synchronizer, false); + recurse(synchronizer); uintptr_t synch_id = synchronizer.primitive_id(); deps.insert({synch_id, std::shared_future{}});