diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 3501f7bf4..7c4f2ab1e 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) { } bool use_cuda_graphs() { - static bool use_graphs = []() { - return env::get_var("MLX_USE_CUDA_GRAPHS", true); - }(); + static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true); return use_graphs; } +const char* save_cuda_graphs_dot_file() { + static const char* filename = []() -> const char* { + const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE"); + if (env && std::strlen(env) == 0) { + return nullptr; + } + return env; + }(); + return filename; +} + } // namespace Device::Device(int device) : device_(device) { @@ -421,6 +430,14 @@ void CommandEncoder::commit() { CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); } + + // Save cuda graph to dot file + if (const char* filename = save_cuda_graphs_dot_file(); filename) { + static int count = 0; + auto path = fmt::format("{}_{}.dot", filename, ++count); + CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0)); + } + // Reset state from_nodes_.clear(); to_nodes_.clear();