Optionally specify names for arrays when exporting (#1749)

This commit is contained in:
Angelos Katharopoulos 2025-01-06 13:07:46 -08:00 committed by GitHub
parent 058d6ce683
commit 25b3a3e541
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 103 additions and 29 deletions

View File

@ -30,6 +30,10 @@ const std::string& NodeNamer::get_name(const array& x) {
return it->second; return it->second;
} }
// Record `n` as the name for array `x`. The entry is keyed by `x.id()`, so a
// name stored earlier under the same id is overwritten.
void NodeNamer::set_name(const array& x, std::string n) {
  auto& slot = names[x.id()];
  slot = std::move(n);
}
void depth_first_traversal( void depth_first_traversal(
std::function<void(array)> callback, std::function<void(array)> callback,
const std::vector<array>& outputs) { const std::vector<array>& outputs) {
@ -55,7 +59,10 @@ void depth_first_traversal(
} }
} }
void print_graph(std::ostream& os, const std::vector<array>& outputs) { void print_graph(
std::ostream& os,
NodeNamer namer,
const std::vector<array>& outputs) {
std::vector<array> tape; std::vector<array> tape;
std::vector<array> inputs; std::vector<array> inputs;
@ -69,7 +76,6 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
}, },
outputs); outputs);
NodeNamer namer;
auto print_arrs = [&namer, &os](std::vector<array> arrs) { auto print_arrs = [&namer, &os](std::vector<array> arrs) {
for (auto& arr : arrs) { for (auto& arr : arrs) {
os << namer.get_name(arr); os << namer.get_name(arr);
@ -96,20 +102,39 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
} }
} }
void export_to_dot(std::ostream& os, const std::vector<array>& outputs) { void export_to_dot(
std::ostream& os,
NodeNamer namer,
const std::vector<array>& nodes) {
// Perform one DFS to mark arrays as intermediate if they are used as inputs
// to other arrays.
std::unordered_set<std::uintptr_t> intermediate_set;
depth_first_traversal(
[&](const array& x) {
// No primitive so it is an input
if (!x.has_primitive()) {
return;
}
for (auto& a : x.inputs()) {
intermediate_set.insert(a.id());
}
},
nodes);
// Now we have everything we need to make the graph. Arrays can be one of 3
// things:
// 1. Inputs, when they have no primitive, i.e. are evaluated
// 2. Intermediates, when they are in the intermediate set
// 3. Outputs, if they are not inputs and not intermediates
os << "digraph {" << std::endl; os << "digraph {" << std::endl;
std::unordered_set<std::uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
}
std::unordered_set<std::uintptr_t> input_set;
NodeNamer namer;
depth_first_traversal( depth_first_traversal(
[&](const array& x) { [&](const array& x) {
if (!x.has_primitive()) { if (!x.has_primitive()) {
input_set.insert(x.id()); os << "{ rank=source; \"" << namer.get_name(x) << "\"; }"
os << "{ rank=source; " << namer.get_name(x) << "; }" << std::endl; << std::endl;
return; return;
} }
@ -123,24 +148,26 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
os << "; }" << std::endl; os << "; }" << std::endl;
// Arrows to primitive's inputs // Arrows to primitive's inputs
for (auto& a : x.inputs()) { for (auto& a : x.inputs()) {
os << namer.get_name(a) << " -> " << x.primitive_id() << std::endl; os << '"' << namer.get_name(a) << "\" -> " << x.primitive_id()
<< std::endl;
} }
} }
// Point outputs to their primitive // Point outputs to their primitive
for (auto& a : x.outputs()) { for (auto& a : x.outputs()) {
os << "{ "; os << "{ ";
if (output_set.find(a.id()) != output_set.end()) { if (intermediate_set.find(a.id()) == intermediate_set.end()) {
os << "rank=sink; "; os << "rank=sink; ";
} }
os << namer.get_name(a); os << '"' << namer.get_name(a);
os << "; }" << std::endl; os << "\"; }" << std::endl;
if (x.has_primitive()) { if (x.has_primitive()) {
os << x.primitive_id() << " -> " << namer.get_name(a) << std::endl; os << x.primitive_id() << " -> \"" << namer.get_name(a) << '"'
<< std::endl;
} }
} }
}, },
outputs); nodes);
os << "}"; os << "}";
} }

View File

@ -12,20 +12,55 @@ struct NodeNamer {
std::unordered_map<std::uintptr_t, std::string> names; std::unordered_map<std::uintptr_t, std::string> names;
const std::string& get_name(const array& x); const std::string& get_name(const array& x);
void set_name(const array& x, std::string n);
}; };
void print_graph(std::ostream& os, const std::vector<array>& outputs); void print_graph(
std::ostream& os,
NodeNamer namer,
const std::vector<array>& outputs);
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>> inline void print_graph(std::ostream& os, const std::vector<array>& outputs) {
void print_graph(std::ostream& os, Arrays&&... outputs) { print_graph(os, NodeNamer{}, outputs);
print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
} }
void export_to_dot(std::ostream& os, const std::vector<array>& outputs); template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline void print_graph(std::ostream& os, Arrays&&... outputs) {
print_graph(
os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
}
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>> template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void export_to_dot(std::ostream& os, Arrays&&... outputs) { inline void
export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...}); print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
print_graph(
os,
std::move(namer),
std::vector<array>{std::forward<Arrays>(outputs)...});
}
// Write a DOT ("digraph { ... }") description of the graph that produces
// `outputs` to `os`.
//
// `namer` provides the labels emitted for arrays (through
// `NodeNamer::get_name`). It is taken by value, so the caller's namer is not
// modified by the export.
void export_to_dot(
std::ostream& os,
NodeNamer namer,
const std::vector<array>& outputs);
inline void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
export_to_dot(os, NodeNamer{}, outputs);
}
// Variadic convenience overload: gather the given arrays into a vector and
// export them to `os` in DOT format with a freshly constructed NodeNamer.
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline void export_to_dot(std::ostream& os, Arrays&&... outputs) {
  std::vector<array> collected{std::forward<Arrays>(outputs)...};
  // Pass the vector as an rvalue so overload resolution sees the same
  // argument category as a temporary would.
  export_to_dot(os, NodeNamer{}, std::move(collected));
}
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline void
export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
export_to_dot(
os,
std::move(namer),
std::vector<array>{std::forward<Arrays>(outputs)...});
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -241,14 +241,20 @@ void init_export(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"export_to_dot", "export_to_dot",
[](nb::object file, const nb::args& args) { [](nb::object file, const nb::args& args, const nb::kwargs& kwargs) {
std::vector<mx::array> arrays = tree_flatten(args); std::vector<mx::array> arrays =
tree_flatten(nb::make_tuple(args, kwargs));
mx::NodeNamer namer;
for (const auto& n : kwargs) {
namer.set_name(
nb::cast<mx::array>(n.second), nb::cast<std::string>(n.first));
}
if (nb::isinstance<nb::str>(file)) { if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file)); std::ofstream out(nb::cast<std::string>(file));
mx::export_to_dot(out, arrays); mx::export_to_dot(out, std::move(namer), arrays);
} else if (nb::hasattr(file, "write")) { } else if (nb::hasattr(file, "write")) {
std::ostringstream out; std::ostringstream out;
mx::export_to_dot(out, arrays); mx::export_to_dot(out, std::move(namer), arrays);
auto write = file.attr("write"); auto write = file.attr("write");
write(out.str()); write(out.str());
} else { } else {
@ -259,19 +265,25 @@ void init_export(nb::module_& m) {
}, },
"file"_a, "file"_a,
"args"_a, "args"_a,
"kwargs"_a,
R"pbdoc( R"pbdoc(
Export a graph to DOT format for visualization. Export a graph to DOT format for visualization.
A variable number of output arrays can be provided for exporting A variable number of output arrays can be provided for exporting
The graph exported will recursively include all enevaluated inputs of The graph exported will recursively include all unevaluated inputs of
the provided outputs. the provided outputs.
Args: Args:
file (str): The file path to export to. file (str): The file path to export to.
*args (array): The output arrays. *args (array): The output arrays.
**kwargs (dict[str, array]): Provide some names for arrays in the
graph to make the result easier to parse.
Example: Example:
>>> a = mx.array(1) + mx.array(2) >>> a = mx.array(1) + mx.array(2)
>>> mx.export_to_dot("graph.dot", a) >>> mx.export_to_dot("graph.dot", a)
>>> x = mx.array(1)
>>> y = mx.array(2)
>>> mx.export_to_dot("graph.dot", x + y, x=x, y=y)
)pbdoc"); )pbdoc");
} }