mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Remove depth traversal (#813)
* no depth traversal * counter outside loop
This commit is contained in:
parent
28301807c2
commit
a4d290adb9
@ -104,13 +104,11 @@ void array::detach() {
|
|||||||
s.array_desc_->inputs.clear();
|
s.array_desc_->inputs.clear();
|
||||||
s.array_desc_->siblings.clear();
|
s.array_desc_->siblings.clear();
|
||||||
s.array_desc_->position = 0;
|
s.array_desc_->position = 0;
|
||||||
s.array_desc_->depth = 0;
|
|
||||||
s.array_desc_->primitive = nullptr;
|
s.array_desc_->primitive = nullptr;
|
||||||
}
|
}
|
||||||
array_desc_->inputs.clear();
|
array_desc_->inputs.clear();
|
||||||
array_desc_->siblings.clear();
|
array_desc_->siblings.clear();
|
||||||
array_desc_->position = 0;
|
array_desc_->position = 0;
|
||||||
array_desc_->depth = 0;
|
|
||||||
array_desc_->primitive = nullptr;
|
array_desc_->primitive = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -189,9 +187,7 @@ array::ArrayDesc::ArrayDesc(
|
|||||||
std::tie(size, strides) = cum_prod(this->shape);
|
std::tie(size, strides) = cum_prod(this->shape);
|
||||||
for (auto& in : this->inputs) {
|
for (auto& in : this->inputs) {
|
||||||
is_tracer |= in.is_tracer();
|
is_tracer |= in.is_tracer();
|
||||||
depth = std::max(in.graph_depth(), depth);
|
|
||||||
}
|
}
|
||||||
depth++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array::ArrayDesc::ArrayDesc(
|
array::ArrayDesc::ArrayDesc(
|
||||||
@ -206,9 +202,7 @@ array::ArrayDesc::ArrayDesc(
|
|||||||
std::tie(size, strides) = cum_prod(this->shape);
|
std::tie(size, strides) = cum_prod(this->shape);
|
||||||
for (auto& in : this->inputs) {
|
for (auto& in : this->inputs) {
|
||||||
is_tracer |= in.is_tracer();
|
is_tracer |= in.is_tracer();
|
||||||
depth = std::max(in.graph_depth(), depth);
|
|
||||||
}
|
}
|
||||||
depth++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||||
|
@ -273,11 +273,6 @@ class array {
|
|||||||
return outputs;
|
return outputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
|
|
||||||
uint32_t graph_depth() const {
|
|
||||||
return array_desc_->depth;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Detach the array from the graph. */
|
/** Detach the array from the graph. */
|
||||||
void detach();
|
void detach();
|
||||||
|
|
||||||
@ -388,9 +383,6 @@ class array {
|
|||||||
// The arrays position in the output list
|
// The arrays position in the output list
|
||||||
uint32_t position{0};
|
uint32_t position{0};
|
||||||
|
|
||||||
// The depth of the array in the graph.
|
|
||||||
uint32_t depth{0};
|
|
||||||
|
|
||||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||||
|
|
||||||
explicit ArrayDesc(
|
explicit ArrayDesc(
|
||||||
|
@ -38,7 +38,7 @@ class Synchronizer : public Primitive {
|
|||||||
int detail::InTracing::tracing_counter{0};
|
int detail::InTracing::tracing_counter{0};
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs) {
|
void eval(const std::vector<array>& outputs) {
|
||||||
std::function<void(const array&, bool)> recurse;
|
std::function<void(const array&)> recurse;
|
||||||
std::queue<array> tape;
|
std::queue<array> tape;
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
||||||
@ -55,56 +55,32 @@ void eval(const std::vector<array>& outputs) {
|
|||||||
auto synchronizer =
|
auto synchronizer =
|
||||||
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
|
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
|
||||||
|
|
||||||
recurse = [&](const array& a, bool largest_branch_first) {
|
size_t depth_counter = 0;
|
||||||
|
recurse = [&](const array& a) {
|
||||||
|
if (depth_counter > max_graph_depth) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[eval] Graph depth exceeded maximum allowed limit."
|
||||||
|
" Try evaluating the graph more frequently.");
|
||||||
|
}
|
||||||
|
|
||||||
auto id = a.id();
|
auto id = a.id();
|
||||||
if (cache.find(id) != cache.end()) {
|
if (cache.find(id) != cache.end()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the input is being computed on a different stream, we need to manage
|
// Recurse to the largest or smallest branch first.
|
||||||
// the dependency.
|
depth_counter++;
|
||||||
auto check_dependency = [&](const array& in) {
|
for (auto& in : a.inputs()) {
|
||||||
|
recurse(in);
|
||||||
if (!in.is_evaled()) {
|
if (!in.is_evaled()) {
|
||||||
|
// If the input is being computed on a different stream, we need to
|
||||||
|
// manage the dependency.
|
||||||
if (a.primitive().stream() != in.primitive().stream()) {
|
if (a.primitive().stream() != in.primitive().stream()) {
|
||||||
deps.insert({in.primitive_id(), std::shared_future<void>{}});
|
deps.insert({in.primitive_id(), std::shared_future<void>{}});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
// Recurse to the largest or smallest branch first.
|
|
||||||
size_t num_inputs = a.inputs().size();
|
|
||||||
if (num_inputs == 1) {
|
|
||||||
auto& in = a.inputs()[0];
|
|
||||||
recurse(in, true);
|
|
||||||
check_dependency(in);
|
|
||||||
} else if (num_inputs == 2) {
|
|
||||||
auto depth_1 = a.inputs()[0].graph_depth();
|
|
||||||
auto depth_2 = a.inputs()[1].graph_depth();
|
|
||||||
auto& in1 = a.inputs()[static_cast<int>(
|
|
||||||
!((depth_1 > depth_2) == largest_branch_first))];
|
|
||||||
auto& in2 = a.inputs()[static_cast<int>(
|
|
||||||
((depth_1 > depth_2) == largest_branch_first))];
|
|
||||||
recurse(in1, true);
|
|
||||||
check_dependency(in1);
|
|
||||||
recurse(in2, true);
|
|
||||||
check_dependency(in2);
|
|
||||||
} else if (num_inputs > 2) {
|
|
||||||
std::vector<int> recursion_order(a.inputs().size());
|
|
||||||
std::iota(recursion_order.begin(), recursion_order.end(), 0);
|
|
||||||
std::sort(
|
|
||||||
recursion_order.begin(),
|
|
||||||
recursion_order.end(),
|
|
||||||
[&a, largest_branch_first](int i, int j) {
|
|
||||||
auto depth_i = a.inputs()[i].graph_depth();
|
|
||||||
auto depth_j = a.inputs()[j].graph_depth();
|
|
||||||
return largest_branch_first ? depth_i > depth_j : depth_j < depth_i;
|
|
||||||
});
|
|
||||||
for (int idx : recursion_order) {
|
|
||||||
auto& in = a.inputs()[idx];
|
|
||||||
recurse(in, true);
|
|
||||||
check_dependency(in);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
depth_counter--;
|
||||||
|
|
||||||
cache.insert(id);
|
cache.insert(id);
|
||||||
for (auto& s : a.siblings()) {
|
for (auto& s : a.siblings()) {
|
||||||
@ -119,12 +95,7 @@ void eval(const std::vector<array>& outputs) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (synchronizer.graph_depth() > max_graph_depth) {
|
recurse(synchronizer);
|
||||||
throw std::runtime_error(
|
|
||||||
"[eval] Graph depth exceeded maximum allowed limit."
|
|
||||||
" Try evaluating the graph more frequently.");
|
|
||||||
}
|
|
||||||
recurse(synchronizer, false);
|
|
||||||
uintptr_t synch_id = synchronizer.primitive_id();
|
uintptr_t synch_id = synchronizer.primitive_id();
|
||||||
deps.insert({synch_id, std::shared_future<void>{}});
|
deps.insert({synch_id, std::shared_future<void>{}});
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user