mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-10 05:59:04 +08:00
[CUDA] Add debug env to save cuda graphs to dot files (#2825)
This commit is contained in:
@@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||
}
|
||||
|
||||
bool use_cuda_graphs() {
|
||||
static bool use_graphs = []() {
|
||||
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||
}();
|
||||
static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||
return use_graphs;
|
||||
}
|
||||
|
||||
// Returns the value of the MLX_SAVE_CUDA_GRAPHS_DOT_FILE environment
// variable, or nullptr when the variable is unset or set to an empty string.
//
// The environment is read once, on the first call; every later call returns
// the same pointer. The value is copied into storage owned by this function
// because the buffer returned by std::getenv may be invalidated or
// overwritten by later setenv/putenv/unsetenv calls, so caching that raw
// pointer for the lifetime of the process is not safe.
const char* save_cuda_graphs_dot_file() {
  static const std::string* const filename = []() -> const std::string* {
    const char* value = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
    if (value == nullptr || value[0] == '\0') {
      return nullptr;
    }
    // Intentionally never freed: the returned pointer must stay valid for
    // any caller, including ones that run during static destruction.
    return new std::string(value);
  }();
  return filename != nullptr ? filename->c_str() : nullptr;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
@@ -421,6 +430,14 @@ void CommandEncoder::commit() {
|
||||
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
}
|
||||
|
||||
// Save cuda graph to dot file
|
||||
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
|
||||
static int count = 0;
|
||||
auto path = fmt::format("{}_{}.dot", filename, ++count);
|
||||
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
|
||||
}
|
||||
|
||||
// Reset state
|
||||
from_nodes_.clear();
|
||||
to_nodes_.clear();
|
||||
|
||||
Reference in New Issue
Block a user