mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-12 15:24:57 +08:00
Optionally specify names for arrays when exporting (#1749)
This commit is contained in:

committed by
GitHub

parent
058d6ce683
commit
25b3a3e541
@@ -241,14 +241,20 @@ void init_export(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](nb::object file, const nb::args& args) {
|
||||
std::vector<mx::array> arrays = tree_flatten(args);
|
||||
[](nb::object file, const nb::args& args, const nb::kwargs& kwargs) {
|
||||
std::vector<mx::array> arrays =
|
||||
tree_flatten(nb::make_tuple(args, kwargs));
|
||||
mx::NodeNamer namer;
|
||||
for (const auto& n : kwargs) {
|
||||
namer.set_name(
|
||||
nb::cast<mx::array>(n.second), nb::cast<std::string>(n.first));
|
||||
}
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
std::ofstream out(nb::cast<std::string>(file));
|
||||
mx::export_to_dot(out, arrays);
|
||||
mx::export_to_dot(out, std::move(namer), arrays);
|
||||
} else if (nb::hasattr(file, "write")) {
|
||||
std::ostringstream out;
|
||||
mx::export_to_dot(out, arrays);
|
||||
mx::export_to_dot(out, std::move(namer), arrays);
|
||||
auto write = file.attr("write");
|
||||
write(out.str());
|
||||
} else {
|
||||
@@ -259,19 +265,25 @@ void init_export(nb::module_& m) {
|
||||
},
|
||||
"file"_a,
|
||||
"args"_a,
|
||||
"kwargs"_a,
|
||||
R"pbdoc(
|
||||
Export a graph to DOT format for visualization.
|
||||
|
||||
A variable number of output arrays can be provided for exporting
|
||||
The graph exported will recursively include all enevaluated inputs of
|
||||
The graph exported will recursively include all unevaluated inputs of
|
||||
the provided outputs.
|
||||
|
||||
Args:
|
||||
file (str): The file path to export to.
|
||||
*args (array): The output arrays.
|
||||
**kwargs (dict[str, array]): Provide some names for arrays in the
|
||||
graph to make the result easier to parse.
|
||||
|
||||
Example:
|
||||
>>> a = mx.array(1) + mx.array(2)
|
||||
>>> mx.export_to_dot("graph.dot", a)
|
||||
>>> x = mx.array(1)
|
||||
>>> y = mx.array(2)
|
||||
>>> mx.export_to_dot("graph.dot", x + y, x=x, y=y)
|
||||
)pbdoc");
|
||||
}
|
||||
|
Reference in New Issue
Block a user