Optionally specify names for arrays when exporting (#1749)

This commit is contained in:
Angelos Katharopoulos
2025-01-06 13:07:46 -08:00
committed by GitHub
parent 058d6ce683
commit 25b3a3e541
3 changed files with 103 additions and 29 deletions

View File

@@ -30,6 +30,10 @@ const std::string& NodeNamer::get_name(const array& x) {
return it->second;
}
void NodeNamer::set_name(const array& x, std::string n) {
names[x.id()] = std::move(n);
}
void depth_first_traversal(
std::function<void(array)> callback,
const std::vector<array>& outputs) {
@@ -55,7 +59,10 @@ void depth_first_traversal(
}
}
void print_graph(std::ostream& os, const std::vector<array>& outputs) {
void print_graph(
std::ostream& os,
NodeNamer namer,
const std::vector<array>& outputs) {
std::vector<array> tape;
std::vector<array> inputs;
@@ -69,7 +76,6 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
},
outputs);
NodeNamer namer;
auto print_arrs = [&namer, &os](std::vector<array> arrs) {
for (auto& arr : arrs) {
os << namer.get_name(arr);
@@ -96,20 +102,39 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
}
}
void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
void export_to_dot(
std::ostream& os,
NodeNamer namer,
const std::vector<array>& nodes) {
// Perform one DFS to mark arrays as intermediate if they are used as inputs
// to other arrays.
std::unordered_set<std::uintptr_t> intermediate_set;
depth_first_traversal(
[&](const array& x) {
// No primitive so it is an input
if (!x.has_primitive()) {
return;
}
for (auto& a : x.inputs()) {
intermediate_set.insert(a.id());
}
},
nodes);
// Now we got everything we need to make the graph. Arrays can be one of 3
// things:
// 1. Inputs, when they have no primitive ie are evaluated
// 2. Intermediates, when they are the intermediate set
// 3. Outputs, if they are not inputs and not intermediates
os << "digraph {" << std::endl;
std::unordered_set<std::uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
}
std::unordered_set<std::uintptr_t> input_set;
NodeNamer namer;
depth_first_traversal(
[&](const array& x) {
if (!x.has_primitive()) {
input_set.insert(x.id());
os << "{ rank=source; " << namer.get_name(x) << "; }" << std::endl;
os << "{ rank=source; \"" << namer.get_name(x) << "\"; }"
<< std::endl;
return;
}
@@ -123,24 +148,26 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
os << "; }" << std::endl;
// Arrows to primitive's inputs
for (auto& a : x.inputs()) {
os << namer.get_name(a) << " -> " << x.primitive_id() << std::endl;
os << '"' << namer.get_name(a) << "\" -> " << x.primitive_id()
<< std::endl;
}
}
// Point outputs to their primitive
for (auto& a : x.outputs()) {
os << "{ ";
if (output_set.find(a.id()) != output_set.end()) {
if (intermediate_set.find(a.id()) == intermediate_set.end()) {
os << "rank=sink; ";
}
os << namer.get_name(a);
os << "; }" << std::endl;
os << '"' << namer.get_name(a);
os << "\"; }" << std::endl;
if (x.has_primitive()) {
os << x.primitive_id() << " -> " << namer.get_name(a) << std::endl;
os << x.primitive_id() << " -> \"" << namer.get_name(a) << '"'
<< std::endl;
}
}
},
outputs);
nodes);
os << "}";
}

View File

@@ -12,20 +12,55 @@ struct NodeNamer {
std::unordered_map<std::uintptr_t, std::string> names;
const std::string& get_name(const array& x);
void set_name(const array& x, std::string n);
};
void print_graph(std::ostream& os, const std::vector<array>& outputs);
void print_graph(
std::ostream& os,
NodeNamer namer,
const std::vector<array>& outputs);
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void print_graph(std::ostream& os, Arrays&&... outputs) {
print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
inline void print_graph(std::ostream& os, const std::vector<array>& outputs) {
print_graph(os, NodeNamer{}, outputs);
}
void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline void print_graph(std::ostream& os, Arrays&&... outputs) {
print_graph(
os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
}
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void export_to_dot(std::ostream& os, Arrays&&... outputs) {
export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
inline void
print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
print_graph(
os,
std::move(namer),
std::vector<array>{std::forward<Arrays>(outputs)...});
}
void export_to_dot(
std::ostream& os,
NodeNamer namer,
const std::vector<array>& outputs);
inline void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
export_to_dot(os, NodeNamer{}, outputs);
}
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline void export_to_dot(std::ostream& os, Arrays&&... outputs) {
export_to_dot(
os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
}
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline void
export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
export_to_dot(
os,
std::move(namer),
std::vector<array>{std::forward<Arrays>(outputs)...});
}
} // namespace mlx::core