MLX
Loading...
Searching...
No Matches
graph_utils.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
#include <cstdint>
#include <iosfwd>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "mlx/array.h"
6
7namespace mlx::core {
8
9struct NodeNamer {
10 std::unordered_map<std::uintptr_t, std::string> names;
11
12 const std::string& get_name(const array& x);
13};
14
15void print_graph(std::ostream& os, const std::vector<array>& outputs);
16
17template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
18void print_graph(std::ostream& os, Arrays&&... outputs) {
19 print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
20}
21
22void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
23
24template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
25void export_to_dot(std::ostream& os, Arrays&&... outputs) {
26 export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
27}
28
29} // namespace mlx::core
Definition array.h:20
Definition allocator.h:7
void export_to_dot(std::ostream &os, const std::vector< array > &outputs)
void print_graph(std::ostream &os, const std::vector< array > &outputs)
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:566
Definition graph_utils.h:9
const std::string & get_name(const array &x)
std::unordered_map< std::uintptr_t, std::string > names
Definition graph_utils.h:10