mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Optionally specify names for arrays when exporting (#1749)
This commit is contained in:
parent
058d6ce683
commit
25b3a3e541
@ -30,6 +30,10 @@ const std::string& NodeNamer::get_name(const array& x) {
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NodeNamer::set_name(const array& x, std::string n) {
|
||||||
|
names[x.id()] = std::move(n);
|
||||||
|
}
|
||||||
|
|
||||||
void depth_first_traversal(
|
void depth_first_traversal(
|
||||||
std::function<void(array)> callback,
|
std::function<void(array)> callback,
|
||||||
const std::vector<array>& outputs) {
|
const std::vector<array>& outputs) {
|
||||||
@ -55,7 +59,10 @@ void depth_first_traversal(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void print_graph(std::ostream& os, const std::vector<array>& outputs) {
|
void print_graph(
|
||||||
|
std::ostream& os,
|
||||||
|
NodeNamer namer,
|
||||||
|
const std::vector<array>& outputs) {
|
||||||
std::vector<array> tape;
|
std::vector<array> tape;
|
||||||
std::vector<array> inputs;
|
std::vector<array> inputs;
|
||||||
|
|
||||||
@ -69,7 +76,6 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
|
|||||||
},
|
},
|
||||||
outputs);
|
outputs);
|
||||||
|
|
||||||
NodeNamer namer;
|
|
||||||
auto print_arrs = [&namer, &os](std::vector<array> arrs) {
|
auto print_arrs = [&namer, &os](std::vector<array> arrs) {
|
||||||
for (auto& arr : arrs) {
|
for (auto& arr : arrs) {
|
||||||
os << namer.get_name(arr);
|
os << namer.get_name(arr);
|
||||||
@ -96,20 +102,39 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
void export_to_dot(
|
||||||
|
std::ostream& os,
|
||||||
|
NodeNamer namer,
|
||||||
|
const std::vector<array>& nodes) {
|
||||||
|
// Perform one DFS to mark arrays as intermediate if they are used as inputs
|
||||||
|
// to other arrays.
|
||||||
|
std::unordered_set<std::uintptr_t> intermediate_set;
|
||||||
|
depth_first_traversal(
|
||||||
|
[&](const array& x) {
|
||||||
|
// No primitive so it is an input
|
||||||
|
if (!x.has_primitive()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& a : x.inputs()) {
|
||||||
|
intermediate_set.insert(a.id());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
nodes);
|
||||||
|
|
||||||
|
// Now we got everything we need to make the graph. Arrays can be one of 3
|
||||||
|
// things:
|
||||||
|
// 1. Inputs, when they have no primitive ie are evaluated
|
||||||
|
// 2. Intermediates, when they are the intermediate set
|
||||||
|
// 3. Outputs, if they are not inputs and not intermediates
|
||||||
|
|
||||||
os << "digraph {" << std::endl;
|
os << "digraph {" << std::endl;
|
||||||
|
|
||||||
std::unordered_set<std::uintptr_t> output_set;
|
|
||||||
for (auto& o : outputs) {
|
|
||||||
output_set.insert(o.id());
|
|
||||||
}
|
|
||||||
std::unordered_set<std::uintptr_t> input_set;
|
|
||||||
NodeNamer namer;
|
|
||||||
depth_first_traversal(
|
depth_first_traversal(
|
||||||
[&](const array& x) {
|
[&](const array& x) {
|
||||||
if (!x.has_primitive()) {
|
if (!x.has_primitive()) {
|
||||||
input_set.insert(x.id());
|
os << "{ rank=source; \"" << namer.get_name(x) << "\"; }"
|
||||||
os << "{ rank=source; " << namer.get_name(x) << "; }" << std::endl;
|
<< std::endl;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,24 +148,26 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
|||||||
os << "; }" << std::endl;
|
os << "; }" << std::endl;
|
||||||
// Arrows to primitive's inputs
|
// Arrows to primitive's inputs
|
||||||
for (auto& a : x.inputs()) {
|
for (auto& a : x.inputs()) {
|
||||||
os << namer.get_name(a) << " -> " << x.primitive_id() << std::endl;
|
os << '"' << namer.get_name(a) << "\" -> " << x.primitive_id()
|
||||||
|
<< std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Point outputs to their primitive
|
// Point outputs to their primitive
|
||||||
for (auto& a : x.outputs()) {
|
for (auto& a : x.outputs()) {
|
||||||
os << "{ ";
|
os << "{ ";
|
||||||
if (output_set.find(a.id()) != output_set.end()) {
|
if (intermediate_set.find(a.id()) == intermediate_set.end()) {
|
||||||
os << "rank=sink; ";
|
os << "rank=sink; ";
|
||||||
}
|
}
|
||||||
os << namer.get_name(a);
|
os << '"' << namer.get_name(a);
|
||||||
os << "; }" << std::endl;
|
os << "\"; }" << std::endl;
|
||||||
if (x.has_primitive()) {
|
if (x.has_primitive()) {
|
||||||
os << x.primitive_id() << " -> " << namer.get_name(a) << std::endl;
|
os << x.primitive_id() << " -> \"" << namer.get_name(a) << '"'
|
||||||
|
<< std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
outputs);
|
nodes);
|
||||||
|
|
||||||
os << "}";
|
os << "}";
|
||||||
}
|
}
|
||||||
|
@ -12,20 +12,55 @@ struct NodeNamer {
|
|||||||
std::unordered_map<std::uintptr_t, std::string> names;
|
std::unordered_map<std::uintptr_t, std::string> names;
|
||||||
|
|
||||||
const std::string& get_name(const array& x);
|
const std::string& get_name(const array& x);
|
||||||
|
void set_name(const array& x, std::string n);
|
||||||
};
|
};
|
||||||
|
|
||||||
void print_graph(std::ostream& os, const std::vector<array>& outputs);
|
void print_graph(
|
||||||
|
std::ostream& os,
|
||||||
|
NodeNamer namer,
|
||||||
|
const std::vector<array>& outputs);
|
||||||
|
|
||||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
inline void print_graph(std::ostream& os, const std::vector<array>& outputs) {
|
||||||
void print_graph(std::ostream& os, Arrays&&... outputs) {
|
print_graph(os, NodeNamer{}, outputs);
|
||||||
print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
|
inline void print_graph(std::ostream& os, Arrays&&... outputs) {
|
||||||
|
print_graph(
|
||||||
|
os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
|
}
|
||||||
|
|
||||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
void export_to_dot(std::ostream& os, Arrays&&... outputs) {
|
inline void
|
||||||
export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
|
||||||
|
print_graph(
|
||||||
|
os,
|
||||||
|
std::move(namer),
|
||||||
|
std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_to_dot(
|
||||||
|
std::ostream& os,
|
||||||
|
NodeNamer namer,
|
||||||
|
const std::vector<array>& outputs);
|
||||||
|
|
||||||
|
inline void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
||||||
|
export_to_dot(os, NodeNamer{}, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
|
inline void export_to_dot(std::ostream& os, Arrays&&... outputs) {
|
||||||
|
export_to_dot(
|
||||||
|
os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
|
inline void
|
||||||
|
export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
|
||||||
|
export_to_dot(
|
||||||
|
os,
|
||||||
|
std::move(namer),
|
||||||
|
std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -241,14 +241,20 @@ void init_export(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"export_to_dot",
|
"export_to_dot",
|
||||||
[](nb::object file, const nb::args& args) {
|
[](nb::object file, const nb::args& args, const nb::kwargs& kwargs) {
|
||||||
std::vector<mx::array> arrays = tree_flatten(args);
|
std::vector<mx::array> arrays =
|
||||||
|
tree_flatten(nb::make_tuple(args, kwargs));
|
||||||
|
mx::NodeNamer namer;
|
||||||
|
for (const auto& n : kwargs) {
|
||||||
|
namer.set_name(
|
||||||
|
nb::cast<mx::array>(n.second), nb::cast<std::string>(n.first));
|
||||||
|
}
|
||||||
if (nb::isinstance<nb::str>(file)) {
|
if (nb::isinstance<nb::str>(file)) {
|
||||||
std::ofstream out(nb::cast<std::string>(file));
|
std::ofstream out(nb::cast<std::string>(file));
|
||||||
mx::export_to_dot(out, arrays);
|
mx::export_to_dot(out, std::move(namer), arrays);
|
||||||
} else if (nb::hasattr(file, "write")) {
|
} else if (nb::hasattr(file, "write")) {
|
||||||
std::ostringstream out;
|
std::ostringstream out;
|
||||||
mx::export_to_dot(out, arrays);
|
mx::export_to_dot(out, std::move(namer), arrays);
|
||||||
auto write = file.attr("write");
|
auto write = file.attr("write");
|
||||||
write(out.str());
|
write(out.str());
|
||||||
} else {
|
} else {
|
||||||
@ -259,19 +265,25 @@ void init_export(nb::module_& m) {
|
|||||||
},
|
},
|
||||||
"file"_a,
|
"file"_a,
|
||||||
"args"_a,
|
"args"_a,
|
||||||
|
"kwargs"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Export a graph to DOT format for visualization.
|
Export a graph to DOT format for visualization.
|
||||||
|
|
||||||
A variable number of output arrays can be provided for exporting
|
A variable number of output arrays can be provided for exporting
|
||||||
The graph exported will recursively include all enevaluated inputs of
|
The graph exported will recursively include all unevaluated inputs of
|
||||||
the provided outputs.
|
the provided outputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (str): The file path to export to.
|
file (str): The file path to export to.
|
||||||
*args (array): The output arrays.
|
*args (array): The output arrays.
|
||||||
|
**kwargs (dict[str, array]): Provide some names for arrays in the
|
||||||
|
graph to make the result easier to parse.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> a = mx.array(1) + mx.array(2)
|
>>> a = mx.array(1) + mx.array(2)
|
||||||
>>> mx.export_to_dot("graph.dot", a)
|
>>> mx.export_to_dot("graph.dot", a)
|
||||||
|
>>> x = mx.array(1)
|
||||||
|
>>> y = mx.array(2)
|
||||||
|
>>> mx.export_to_dot("graph.dot", x + y, x=x, y=y)
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user