mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 09:07:12 +08:00
[CUDA] Add debug env to save cuda graphs to dot files (#2825)
This commit is contained in:
@@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool use_cuda_graphs() {
|
bool use_cuda_graphs() {
|
||||||
static bool use_graphs = []() {
|
static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||||
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
|
||||||
}();
|
|
||||||
return use_graphs;
|
return use_graphs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* save_cuda_graphs_dot_file() {
|
||||||
|
static const char* filename = []() -> const char* {
|
||||||
|
const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
|
||||||
|
if (env && std::strlen(env) == 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return env;
|
||||||
|
}();
|
||||||
|
return filename;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Device::Device(int device) : device_(device) {
|
Device::Device(int device) : device_(device) {
|
||||||
@@ -421,6 +430,14 @@ void CommandEncoder::commit() {
|
|||||||
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save cuda graph to dot file
|
||||||
|
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
|
||||||
|
static int count = 0;
|
||||||
|
auto path = fmt::format("{}_{}.dot", filename, ++count);
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
|
||||||
|
}
|
||||||
|
|
||||||
// Reset state
|
// Reset state
|
||||||
from_nodes_.clear();
|
from_nodes_.clear();
|
||||||
to_nodes_.clear();
|
to_nodes_.clear();
|
||||||
|
|||||||
Reference in New Issue
Block a user