// Copyright © 2023 Apple Inc. #pragma once #include #include "mlx/array.h" namespace mlx::core { struct NodeNamer { std::unordered_map names; const std::string& get_name(const array& x); void set_name(const array& x, std::string n); }; void print_graph( std::ostream& os, NodeNamer namer, const std::vector& outputs); inline void print_graph(std::ostream& os, const std::vector& outputs) { print_graph(os, NodeNamer{}, outputs); } template > inline void print_graph(std::ostream& os, Arrays&&... outputs) { print_graph( os, NodeNamer{}, std::vector{std::forward(outputs)...}); } template > inline void print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) { print_graph( os, std::move(namer), std::vector{std::forward(outputs)...}); } void export_to_dot( std::ostream& os, NodeNamer namer, const std::vector& outputs); inline void export_to_dot(std::ostream& os, const std::vector& outputs) { export_to_dot(os, NodeNamer{}, outputs); } template > inline void export_to_dot(std::ostream& os, Arrays&&... outputs) { export_to_dot( os, NodeNamer{}, std::vector{std::forward(outputs)...}); } template > inline void export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) { export_to_dot( os, std::move(namer), std::vector{std::forward(outputs)...}); } } // namespace mlx::core