diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 246f9bca2..0a38bd0e6 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -318,46 +318,52 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) { insert_graph_dependencies(GraphNode{node, "K"}); } -bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) { - // CUDA graphs do not get updated correctly if a kernel node getting updated - // has a different cluster shape than the node it's being updated with. +std::pair subgraph_to_key(cudaGraph_t graph) { + // Constructs a key representing the nodes of a sub-graph. + // Also checks if the sub-graph is updatable as CUDA graphs do not get + // updated correctly if a kernel node getting updated has a different cluster + // shape than the node it's being updated with. + std::string key = "("; size_t num_nodes = 0; CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes)); if (num_nodes == 0) { - return true; + return {key + ")", true}; } - + bool is_updatable = true; std::vector nodes(num_nodes); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes)); for (const auto& node : nodes) { + if (!is_updatable) { + break; + } cudaGraphNodeType type; CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type)); if (type == cudaGraphNodeTypeGraph) { // Try to be updatable for a structure like graph -> graph -> kernel - if (num_nodes > 1) { - return false; - } cudaGraph_t child; CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child)); - return is_graph_updatable(child, cluster_dim_x); + auto [subkey, sub_is_updatable] = subgraph_to_key(child); + is_updatable &= sub_is_updatable; + key += subkey; + } else if (type == cudaGraphNodeTypeMemset) { + key += "M"; } else if (type != cudaGraphNodeTypeKernel) { - return false; + is_updatable = false; } else { cudaLaunchAttributeValue cluster_dim; CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute( node, cudaLaunchAttributeClusterDimension, &cluster_dim)); - // Only dim.x can be greater than 1 + // Only allow dim.x to be greater than 1 if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) { - return false; + is_updatable = false; + } else { + key += "K"; + key += std::to_string(cluster_dim.clusterDim.x); } - // Only one child node allowed when subgraph uses clusters - if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) { - return false; - } - cluster_dim_x = cluster_dim.clusterDim.x; } } - return true; + key += ")"; + return {key, is_updatable}; } void CommandEncoder::add_graph_node(cudaGraph_t child) { @@ -370,11 +376,10 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) { return; } cudaGraphNode_t node; - int cluster_dim_x = 0; - is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x); + auto [sub_graph_key, is_updatable] = subgraph_to_key(child); + is_graph_updatable_ &= is_updatable; CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); - insert_graph_dependencies( - GraphNode{node, "G" + std::to_string(cluster_dim_x)}); + insert_graph_dependencies(GraphNode{node, sub_graph_key}); } bool CommandEncoder::needs_commit() { diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index e7b14c555..7d317d362 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -106,7 +106,7 @@ class CommandEncoder { cudaGraphNode_t node; // K = kernel // E = empty - // G* = subgraph (with metadata) + // () = subgraph (with metadata) // Symbols ':', '-' are reserved as separators std::string node_type; std::string id; diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 35078fdf6..934f68acd 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -80,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) { } void CudaGraph::end_capture(cudaStream_t stream) { - assert(handle_ == nullptr); CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_)); }