[CUDA] Enable more graphs to be updatable (#2883)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled

This commit is contained in:
Awni Hannun
2025-12-08 06:18:01 -08:00
committed by GitHub
parent a4b3bc969b
commit 27232db1ba
3 changed files with 28 additions and 24 deletions

View File

@@ -318,46 +318,52 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
insert_graph_dependencies(GraphNode{node, "K"}); insert_graph_dependencies(GraphNode{node, "K"});
} }
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) { std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
// CUDA graphs do not get updated correctly if a kernel node getting updated // Constructs a key representing the nodes of a sub-graph.
// has a different cluster shape than the node it's being updated with. // Also checks if the sub-graph is updatable as CUDA graphs do not get
// updated correctly if a kernel node getting updated has a different cluster
// shape than the node it's being updated with.
std::string key = "(";
size_t num_nodes = 0; size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes)); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) { if (num_nodes == 0) {
return true; return {key + ")", true};
} }
bool is_updatable = true;
std::vector<cudaGraphNode_t> nodes(num_nodes); std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes)); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) { for (const auto& node : nodes) {
if (!is_updatable) {
break;
}
cudaGraphNodeType type; cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type)); CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeGraph) { if (type == cudaGraphNodeTypeGraph) {
// Try to be updatable for a structure like graph -> graph -> kernel // Try to be updatable for a structure like graph -> graph -> kernel
if (num_nodes > 1) {
return false;
}
cudaGraph_t child; cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child)); CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
return is_graph_updatable(child, cluster_dim_x); auto [subkey, sub_is_updatable] = subgraph_to_key(child);
is_updatable &= sub_is_updatable;
key += subkey;
} else if (type == cudaGraphNodeTypeMemset) {
key += "M";
} else if (type != cudaGraphNodeTypeKernel) { } else if (type != cudaGraphNodeTypeKernel) {
return false; is_updatable = false;
} else { } else {
cudaLaunchAttributeValue cluster_dim; cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute( CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim)); node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only dim.x can be greater than 1 // Only allow dim.x to be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) { if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
return false; is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
} }
// Only one child node allowed when subgraph uses clusters
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
} }
} }
return true; key += ")";
return {key, is_updatable};
} }
void CommandEncoder::add_graph_node(cudaGraph_t child) { void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -370,11 +376,10 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return; return;
} }
cudaGraphNode_t node; cudaGraphNode_t node;
int cluster_dim_x = 0; auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x); is_graph_updatable_ &= is_updatable;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies( insert_graph_dependencies(GraphNode{node, sub_graph_key});
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
} }
bool CommandEncoder::needs_commit() { bool CommandEncoder::needs_commit() {

View File

@@ -106,7 +106,7 @@ class CommandEncoder {
cudaGraphNode_t node; cudaGraphNode_t node;
// K = kernel // K = kernel
// E = empty // E = empty
// G* = subgraph (with metadata) // () = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators // Symbols ':', '-' are reserved as separators
std::string node_type; std::string node_type;
std::string id; std::string id;

View File

@@ -80,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
} }
void CudaGraph::end_capture(cudaStream_t stream) { void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_)); CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
} }