MLX
Loading...
Searching...
No Matches
graph_utils.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <unordered_map>
6
7#include "mlx/array.h"
8
9namespace mlx::core {
10
11struct NodeNamer {
12 std::unordered_map<std::uintptr_t, std::string> names;
13
14 const std::string& get_name(const array& x);
15};
16
17void print_graph(std::ostream& os, const std::vector<array>& outputs);
18
19template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
20void print_graph(std::ostream& os, Arrays&&... outputs) {
21 print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
22}
23
24void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
25
26template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
27void export_to_dot(std::ostream& os, Arrays&&... outputs) {
28 export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
29}
30
31} // namespace mlx::core
Definition array.h:23
Definition allocator.h:7
void export_to_dot(std::ostream &os, const std::vector< array > &outputs)
void print_graph(std::ostream &os, const std::vector< array > &outputs)
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:614
Definition graph_utils.h:11
const std::string & get_name(const array &x)
std::unordered_map< std::uintptr_t, std::string > names
Definition graph_utils.h:12