MLX
 
Loading...
Searching...
No Matches
graph_utils.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <unordered_map>
6
7#include "mlx/array.h"
8
9namespace mlx::core {
10
11struct NodeNamer {
12 std::unordered_map<std::uintptr_t, std::string> names;
13
14 const std::string& get_name(const array& x);
15 void set_name(const array& x, std::string n);
16};
17
19 std::ostream& os,
20 NodeNamer namer,
21 const std::vector<array>& outputs);
22
23inline void print_graph(std::ostream& os, const std::vector<array>& outputs) {
24 print_graph(os, NodeNamer{}, outputs);
25}
26
27template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
28inline void print_graph(std::ostream& os, Arrays&&... outputs) {
30 os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
31}
32
33template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
34inline void
35print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
37 os,
38 std::move(namer),
39 std::vector<array>{std::forward<Arrays>(outputs)...});
40}
41
43 std::ostream& os,
44 NodeNamer namer,
45 const std::vector<array>& outputs);
46
47inline void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
48 export_to_dot(os, NodeNamer{}, outputs);
49}
50
51template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
52inline void export_to_dot(std::ostream& os, Arrays&&... outputs) {
54 os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
55}
56
57template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
58inline void
59export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
61 os,
62 std::move(namer),
63 std::vector<array>{std::forward<Arrays>(outputs)...});
64}
65
66} // namespace mlx::core
Definition array.h:24
Definition allocator.h:7
void export_to_dot(std::ostream &os, NodeNamer namer, const std::vector< array > &outputs)
void print_graph(std::ostream &os, NodeNamer namer, const std::vector< array > &outputs)
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:630
Definition graph_utils.h:11
const std::string & get_name(const array &x)
std::unordered_map< std::uintptr_t, std::string > names
Definition graph_utils.h:12
void set_name(const array &x, std::string n)