Optionally specify names for arrays when exporting (#1749)

This commit is contained in:
Angelos Katharopoulos
2025-01-06 13:07:46 -08:00
committed by GitHub
parent 058d6ce683
commit 25b3a3e541
3 changed files with 103 additions and 29 deletions

View File

@@ -241,14 +241,20 @@ void init_export(nb::module_& m) {
)pbdoc");
m.def(
"export_to_dot",
[](nb::object file, const nb::args& args) {
std::vector<mx::array> arrays = tree_flatten(args);
[](nb::object file, const nb::args& args, const nb::kwargs& kwargs) {
std::vector<mx::array> arrays =
tree_flatten(nb::make_tuple(args, kwargs));
mx::NodeNamer namer;
for (const auto& n : kwargs) {
namer.set_name(
nb::cast<mx::array>(n.second), nb::cast<std::string>(n.first));
}
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
mx::export_to_dot(out, arrays);
mx::export_to_dot(out, std::move(namer), arrays);
} else if (nb::hasattr(file, "write")) {
std::ostringstream out;
mx::export_to_dot(out, arrays);
mx::export_to_dot(out, std::move(namer), arrays);
auto write = file.attr("write");
write(out.str());
} else {
@@ -259,19 +265,25 @@ void init_export(nb::module_& m) {
},
"file"_a,
"args"_a,
"kwargs"_a,
R"pbdoc(
Export a graph to DOT format for visualization.
A variable number of output arrays can be provided for exporting
The graph exported will recursively include all enevaluated inputs of
The graph exported will recursively include all unevaluated inputs of
the provided outputs.
Args:
file (str): The file path to export to.
*args (array): The output arrays.
**kwargs (dict[str, array]): Provide some names for arrays in the
graph to make the result easier to parse.
Example:
>>> a = mx.array(1) + mx.array(2)
>>> mx.export_to_dot("graph.dot", a)
>>> x = mx.array(1)
>>> y = mx.array(2)
>>> mx.export_to_dot("graph.dot", x + y, x=x, y=y)
)pbdoc");
}