[CUDA] Add debug env to save cuda graphs to dot files (#2825)

This commit is contained in:
Cheng
2025-11-25 15:22:36 +09:00
committed by GitHub
parent bca205e287
commit 23a9168d34

View File

@@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
} }
bool use_cuda_graphs() { bool use_cuda_graphs() {
static bool use_graphs = []() { static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
}();
return use_graphs; return use_graphs;
} }
const char* save_cuda_graphs_dot_file() {
static const char* filename = []() -> const char* {
const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
if (env && std::strlen(env) == 0) {
return nullptr;
}
return env;
}();
return filename;
}
} // namespace } // namespace
Device::Device(int device) : device_(device) { Device::Device(int device) : device_(device) {
@@ -421,6 +430,14 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
} }
// Save cuda graph to dot file
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
static int count = 0;
auto path = fmt::format("{}_{}.dot", filename, ++count);
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
}
// Reset state // Reset state
from_nodes_.clear(); from_nodes_.clear();
to_nodes_.clear(); to_nodes_.clear();