Multi output primitives (#330)

* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-01-08 16:39:08 -08:00
committed by GitHub
parent f45f70f133
commit f099ebe535
26 changed files with 2313 additions and 1039 deletions

View File

@@ -12,13 +12,11 @@
namespace mlx::core {
using OptionalArrayRef = std::optional<std::reference_wrapper<const array>>;
struct ArrayNames {
struct NodeNamer {
std::unordered_map<std::uintptr_t, std::string> names;
std::string get_name(const array& x) {
auto it = names.find(x.id());
std::string get_name(uintptr_t id) {
auto it = names.find(id);
if (it == names.end()) {
// Get the next name in the sequence
// [A, B, ..., Z, AA, AB, ...]
@@ -29,45 +27,42 @@ struct ArrayNames {
var_num = (var_num - 1) / 26;
}
std::string name(letters.rbegin(), letters.rend());
names.insert({x.id(), name});
names.insert({id, name});
return name;
}
return it->second;
}
std::string get_name(const array& x) {
return get_name(x.id());
}
};
void depth_first_traversal(
std::function<void(OptionalArrayRef, const array&, int)> callback,
std::function<void(array)> callback,
const std::vector<array>& outputs) {
std::function<void(OptionalArrayRef, const array&, int)> recurse;
std::function<void(const array&)> recurse;
std::unordered_set<std::uintptr_t> cache;
recurse = [&](OptionalArrayRef parent, const array& x, int input_index) {
recurse = [&](const array& x) {
auto id = x.id();
if (cache.find(id) != cache.end()) {
return;
}
cache.insert(id);
for (int i = 0; i < x.inputs().size(); i++) {
recurse(x, x.inputs()[i], i);
for (auto& s : x.siblings()) {
cache.insert(s.id());
}
callback(parent, x, input_index);
for (auto& in : x.inputs()) {
recurse(in);
}
callback(x);
};
for (auto x : outputs) {
recurse(std::nullopt, x, 0);
for (auto& o : outputs) {
recurse(o);
}
}
void depth_first_traversal(
std::function<void(const array&)> callback,
const std::vector<array>& outputs) {
depth_first_traversal(
[&callback](OptionalArrayRef p, const array& x, int input_index) {
callback(x);
},
outputs);
}
void print_graph(std::ostream& os, const std::vector<array>& outputs) {
std::vector<array> tape;
std::vector<array> inputs;
@@ -82,15 +77,11 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
},
outputs);
ArrayNames namer;
auto print_arr = [&namer, &os](const array& a) {
os << namer.get_name(a);
os << " [" << a.shape() << ", " << a.dtype() << "]";
};
auto print_arrs = [&](const std::vector<array>& arrs) {
NodeNamer namer;
auto print_arrs = [&namer, &os](std::vector<array> arrs) {
for (auto& arr : arrs) {
print_arr(arr);
os << namer.get_name(arr);
os << " [" << arr.shape() << ", " << arr.dtype() << "]";
if (&arr != &arrs.back()) {
os << ", ";
}
@@ -108,7 +99,7 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
os << " ";
print_arrs(arr.inputs());
os << " -> ";
print_arr(arr);
print_arrs(arr.outputs());
os << "\n";
}
}
@@ -116,26 +107,47 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
os << "digraph {" << std::endl;
ArrayNames namer;
std::unordered_set<std::uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
}
std::unordered_set<std::uintptr_t> input_set;
NodeNamer namer;
depth_first_traversal(
[&namer, &os](auto parent, const array& x, int input_index) {
os << "{ ";
[&](const array& x) {
if (!x.has_primitive()) {
os << "rank=source; ";
input_set.insert(x.id());
os << "{ rank=source; " << namer.get_name(x) << "; }" << std::endl;
return;
}
if (!parent) {
os << "rank=sink; ";
}
os << namer.get_name(x);
// Node for primitive
if (x.has_primitive()) {
os << "{ ";
os << namer.get_name(x.primitive_id());
os << " [label =\"";
x.primitive().print(os);
os << "\"]";
os << "\", shape=rectangle]";
os << "; }" << std::endl;
// Arrows to primitive's inputs
for (auto& a : x.inputs()) {
os << namer.get_name(x.primitive_id()) << " -> "
<< namer.get_name(a) << std::endl;
}
}
os << "; }" << std::endl;
for (auto c : x.inputs()) {
os << namer.get_name(c) << " -> " << namer.get_name(x) << std::endl;
// Point outputs to their primitive
for (auto& a : x.outputs()) {
os << "{ ";
if (output_set.find(a.id()) != output_set.end()) {
os << "rank=sink; ";
}
os << namer.get_name(a);
os << "; }" << std::endl;
if (x.has_primitive()) {
os << namer.get_name(a) << " -> "
<< namer.get_name(x.primitive_id()) << std::endl;
}
}
},
outputs);