diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index c6b594247..29373f266 100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -30,6 +30,10 @@ const std::string& NodeNamer::get_name(const array& x) { return it->second; } +void NodeNamer::set_name(const array& x, std::string n) { + names[x.id()] = std::move(n); +} + void depth_first_traversal( std::function callback, const std::vector& outputs) { @@ -55,7 +59,10 @@ void depth_first_traversal( } } -void print_graph(std::ostream& os, const std::vector& outputs) { +void print_graph( + std::ostream& os, + NodeNamer namer, + const std::vector& outputs) { std::vector tape; std::vector inputs; @@ -69,7 +76,6 @@ void print_graph(std::ostream& os, const std::vector& outputs) { }, outputs); - NodeNamer namer; auto print_arrs = [&namer, &os](std::vector arrs) { for (auto& arr : arrs) { os << namer.get_name(arr); @@ -96,20 +102,39 @@ void print_graph(std::ostream& os, const std::vector& outputs) { } } -void export_to_dot(std::ostream& os, const std::vector& outputs) { +void export_to_dot( + std::ostream& os, + NodeNamer namer, + const std::vector& nodes) { + // Perform one DFS to mark arrays as intermediate if they are used as inputs + // to other arrays. + std::unordered_set intermediate_set; + depth_first_traversal( + [&](const array& x) { + // No primitive so it is an input + if (!x.has_primitive()) { + return; + } + + for (auto& a : x.inputs()) { + intermediate_set.insert(a.id()); + } + }, + nodes); + + // Now we have everything we need to make the graph. Arrays can be one of 3 + // things: + // 1. Inputs, when they have no primitive, i.e. are already evaluated + // 2. Intermediates, when they are in the intermediate set + // 3. 
Outputs, if they are not inputs and not intermediates + os << "digraph {" << std::endl; - std::unordered_set output_set; - for (auto& o : outputs) { - output_set.insert(o.id()); - } - std::unordered_set input_set; - NodeNamer namer; depth_first_traversal( [&](const array& x) { if (!x.has_primitive()) { - input_set.insert(x.id()); - os << "{ rank=source; " << namer.get_name(x) << "; }" << std::endl; + os << "{ rank=source; \"" << namer.get_name(x) << "\"; }" + << std::endl; return; } @@ -123,24 +148,26 @@ void export_to_dot(std::ostream& os, const std::vector& outputs) { os << "; }" << std::endl; // Arrows to primitive's inputs for (auto& a : x.inputs()) { - os << namer.get_name(a) << " -> " << x.primitive_id() << std::endl; + os << '"' << namer.get_name(a) << "\" -> " << x.primitive_id() + << std::endl; } } // Point outputs to their primitive for (auto& a : x.outputs()) { os << "{ "; - if (output_set.find(a.id()) != output_set.end()) { + if (intermediate_set.find(a.id()) == intermediate_set.end()) { os << "rank=sink; "; } - os << namer.get_name(a); - os << "; }" << std::endl; + os << '"' << namer.get_name(a); + os << "\"; }" << std::endl; if (x.has_primitive()) { - os << x.primitive_id() << " -> " << namer.get_name(a) << std::endl; + os << x.primitive_id() << " -> \"" << namer.get_name(a) << '"' + << std::endl; } } }, - outputs); + nodes); os << "}"; } diff --git a/mlx/graph_utils.h b/mlx/graph_utils.h index affb5e078..fcbeef15e 100644 --- a/mlx/graph_utils.h +++ b/mlx/graph_utils.h @@ -12,20 +12,55 @@ struct NodeNamer { std::unordered_map names; const std::string& get_name(const array& x); + void set_name(const array& x, std::string n); }; -void print_graph(std::ostream& os, const std::vector& outputs); +void print_graph( + std::ostream& os, + NodeNamer namer, + const std::vector& outputs); -template > -void print_graph(std::ostream& os, Arrays&&... 
outputs) { - print_graph(os, std::vector{std::forward(outputs)...}); +inline void print_graph(std::ostream& os, const std::vector& outputs) { + print_graph(os, NodeNamer{}, outputs); } -void export_to_dot(std::ostream& os, const std::vector& outputs); +template > +inline void print_graph(std::ostream& os, Arrays&&... outputs) { + print_graph( + os, NodeNamer{}, std::vector{std::forward(outputs)...}); +} template > -void export_to_dot(std::ostream& os, Arrays&&... outputs) { - export_to_dot(os, std::vector{std::forward(outputs)...}); +inline void +print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) { + print_graph( + os, + std::move(namer), + std::vector{std::forward(outputs)...}); +} + +void export_to_dot( + std::ostream& os, + NodeNamer namer, + const std::vector& outputs); + +inline void export_to_dot(std::ostream& os, const std::vector& outputs) { + export_to_dot(os, NodeNamer{}, outputs); +} + +template > +inline void export_to_dot(std::ostream& os, Arrays&&... outputs) { + export_to_dot( + os, NodeNamer{}, std::vector{std::forward(outputs)...}); +} + +template > +inline void +export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... 
outputs) { + export_to_dot( + os, + std::move(namer), + std::vector{std::forward(outputs)...}); } } // namespace mlx::core diff --git a/python/src/export.cpp b/python/src/export.cpp index 172bb4cc4..b2088587a 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -241,14 +241,20 @@ void init_export(nb::module_& m) { )pbdoc"); m.def( "export_to_dot", - [](nb::object file, const nb::args& args) { - std::vector arrays = tree_flatten(args); + [](nb::object file, const nb::args& args, const nb::kwargs& kwargs) { + std::vector arrays = + tree_flatten(nb::make_tuple(args, kwargs)); + mx::NodeNamer namer; + for (const auto& n : kwargs) { + namer.set_name( + nb::cast(n.second), nb::cast(n.first)); + } if (nb::isinstance(file)) { std::ofstream out(nb::cast(file)); - mx::export_to_dot(out, arrays); + mx::export_to_dot(out, std::move(namer), arrays); } else if (nb::hasattr(file, "write")) { std::ostringstream out; - mx::export_to_dot(out, arrays); + mx::export_to_dot(out, std::move(namer), arrays); auto write = file.attr("write"); write(out.str()); } else { @@ -259,19 +265,25 @@ void init_export(nb::module_& m) { }, "file"_a, "args"_a, + "kwargs"_a, R"pbdoc( Export a graph to DOT format for visualization. A variable number of output arrays can be provided for exporting - The graph exported will recursively include all enevaluated inputs of + The graph exported will recursively include all unevaluated inputs of the provided outputs. Args: file (str): The file path to export to. *args (array): The output arrays. + **kwargs (dict[str, array]): Provide some names for arrays in the + graph to make the result easier to parse. Example: >>> a = mx.array(1) + mx.array(2) >>> mx.export_to_dot("graph.dot", a) + >>> x = mx.array(1) + >>> y = mx.array(2) + >>> mx.export_to_dot("graph.dot", x + y, x=x, y=y) )pbdoc"); }