mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Fix stub generation, change graph exporting for arrows to go to outputs (#455)
This commit is contained in:
@@ -15,8 +15,8 @@ namespace mlx::core {
|
||||
struct NodeNamer {
|
||||
std::unordered_map<std::uintptr_t, std::string> names;
|
||||
|
||||
std::string get_name(uintptr_t id) {
|
||||
auto it = names.find(id);
|
||||
std::string get_name(const array& x) {
|
||||
auto it = names.find(x.id());
|
||||
if (it == names.end()) {
|
||||
// Get the next name in the sequence
|
||||
// [A, B, ..., Z, AA, AB, ...]
|
||||
@@ -27,15 +27,11 @@ struct NodeNamer {
|
||||
var_num = (var_num - 1) / 26;
|
||||
}
|
||||
std::string name(letters.rbegin(), letters.rend());
|
||||
names.insert({id, name});
|
||||
names.insert({x.id(), name});
|
||||
return name;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::string get_name(const array& x) {
|
||||
return get_name(x.id());
|
||||
}
|
||||
};
|
||||
|
||||
void depth_first_traversal(
|
||||
@@ -124,15 +120,14 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
||||
// Node for primitive
|
||||
if (x.has_primitive()) {
|
||||
os << "{ ";
|
||||
os << namer.get_name(x.primitive_id());
|
||||
os << x.primitive_id();
|
||||
os << " [label =\"";
|
||||
x.primitive().print(os);
|
||||
os << "\", shape=rectangle]";
|
||||
os << "; }" << std::endl;
|
||||
// Arrows to primitive's inputs
|
||||
for (auto& a : x.inputs()) {
|
||||
os << namer.get_name(x.primitive_id()) << " -> "
|
||||
<< namer.get_name(a) << std::endl;
|
||||
os << namer.get_name(a) << " -> " << x.primitive_id() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,8 +140,7 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
||||
os << namer.get_name(a);
|
||||
os << "; }" << std::endl;
|
||||
if (x.has_primitive()) {
|
||||
os << namer.get_name(a) << " -> "
|
||||
<< namer.get_name(x.primitive_id()) << std::endl;
|
||||
os << x.primitive_id() << " -> " << namer.get_name(a) << std::endl;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
Reference in New Issue
Block a user