diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 24b85821e..3501f7bf4 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -115,18 +115,17 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
   }
 
   // Use an empty graph node for synchronization
-  CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
-  enc.empty_node_count_++;
+  CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
   CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
 
   // Insert the concurrent -> empty node dependencies
   for (auto& from : enc.concurrent_nodes_) {
     enc.from_nodes_.push_back(from.node);
     enc.to_nodes_.push_back(empty.node);
-    enc.graph_key_ += from.id;
-    enc.graph_key_ += from.node_type;
-    enc.graph_key_ += empty.id;
-    enc.graph_key_ += empty.node_type;
+    enc.graph_deps_key_ += from.id;
+    enc.graph_deps_key_ += "-";
+    enc.graph_deps_key_ += empty.id;
+    enc.graph_deps_key_ += "-";
   }
 
   // Insert the input -> concurrent node dependencies without updating output
@@ -141,9 +140,6 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
 }
 
 void CommandEncoder::insert_graph_dependencies(GraphNode node) {
-  if (node.node_type == 'G') {
-    graph_node_count_++;
-  }
   node.id = std::to_string(node_count_++);
   if (in_concurrent_) {
     concurrent_nodes_.push_back(std::move(node));
@@ -155,6 +151,10 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
 }
 
 void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
+  for (auto& node : nodes) {
+    graph_nodes_key_ += node.node_type;
+    graph_nodes_key_ += "-";
+  }
   std::vector<GraphNode> deps;
   {
     // Dependencies must be added in the same order to produce a consistent
@@ -182,10 +182,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
     for (auto& to : nodes) {
       from_nodes_.push_back(from.node);
       to_nodes_.push_back(to.node);
-      graph_key_ += from.id;
-      graph_key_ += from.node_type;
-      graph_key_ += to.id;
-      graph_key_ += to.node_type;
+      graph_deps_key_ += from.id;
+      graph_deps_key_ += "-";
+      graph_deps_key_ += to.id;
+      graph_deps_key_ += "-";
     }
   }
 }
@@ -309,13 +309,46 @@ void CommandEncoder::add_kernel_node(
 void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
   cudaGraphNode_t node;
   CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  insert_graph_dependencies(GraphNode{node, "K"});
 }
 
 void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
   CUgraphNode node;
   CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  insert_graph_dependencies(GraphNode{node, "K"});
+}
+
+bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
+  // CUDA graphs do not get updated correctly if a kernel node getting updated
+  // has a different cluster shape than the node it's being updated with.
+  size_t num_nodes = 0;
+  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
+  if (num_nodes == 0) {
+    return true;
+  }
+
+  std::vector<cudaGraphNode_t> nodes(num_nodes);
+  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
+  for (const auto& node : nodes) {
+    cudaGraphNodeType type;
+    CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
+    if (type != cudaGraphNodeTypeKernel) {
+      return false;
+    }
+    cudaLaunchAttributeValue cluster_dim;
+    CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
+        node, cudaLaunchAttributeClusterDimension, &cluster_dim));
+    // Only dim.x can be greater than 1
+    if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
+      return false;
+    }
+    // Only one child node allowed when subgraph uses clusters
+    if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
+      return false;
+    }
+    cluster_dim_x = cluster_dim.clusterDim.x;
+  }
+  return true;
 }
 
 void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -328,8 +361,11 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
     return;
   }
   cudaGraphNode_t node;
+  int cluster_dim_x = 0;
+  is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
-  insert_graph_dependencies(GraphNode{node, 'G'});
+  insert_graph_dependencies(
+      GraphNode{node, "G" + std::to_string(cluster_dim_x)});
 }
 
 bool CommandEncoder::needs_commit() {
@@ -354,44 +390,45 @@ void CommandEncoder::commit() {
         from_nodes_.size()));
   }
 
-  graph_key_ += ".";
-  graph_key_ += std::to_string(node_count_);
-  graph_key_ += ".";
-  graph_key_ += std::to_string(graph_node_count_);
-  graph_key_ += ".";
-  graph_key_ += std::to_string(empty_node_count_);
-
-  CudaGraphExec& graph_exec = graph_cache_[graph_key_];
-
-  if (graph_exec != nullptr) {
-    cudaGraphExecUpdateResult update_result;
-#if CUDART_VERSION >= 12000
-    cudaGraphExecUpdateResultInfo info;
-    cudaGraphExecUpdate(graph_exec, graph_, &info);
-    update_result = info.result;
-#else
-    cudaGraphNode_t error_node;
-    cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
-#endif // CUDART_VERSION >= 12000
-    if (update_result != cudaGraphExecUpdateSuccess) {
-      cudaGetLastError(); // reset error
-      graph_exec.reset();
-    }
-  }
-  if (graph_exec == nullptr) {
-    graph_exec.instantiate(graph_);
-  }
   device_.make_current();
-  CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+  if (!is_graph_updatable_) {
+    CudaGraphExec graph_exec;
+    graph_exec.instantiate(graph_);
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+  } else {
+    auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
+    auto& graph_exec = graph_cache_[graph_key];
+
+    if (graph_exec != nullptr) {
+      cudaGraphExecUpdateResult update_result;
+#if CUDART_VERSION >= 12000
+      cudaGraphExecUpdateResultInfo info;
+      cudaGraphExecUpdate(graph_exec, graph_, &info);
+      update_result = info.result;
+#else
+      cudaGraphNode_t error_node;
+      cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
+#endif // CUDART_VERSION >= 12000
+      if (update_result != cudaGraphExecUpdateSuccess) {
+        cudaGetLastError(); // reset error
+        graph_exec.reset();
+      }
+    }
+    if (graph_exec == nullptr) {
+      graph_exec.instantiate(graph_);
+    }
+
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+  }
 
   // Reset state
-  graph_node_count_ = 0;
-  empty_node_count_ = 0;
   from_nodes_.clear();
   to_nodes_.clear();
-  graph_key_.clear();
+  graph_deps_key_.clear();
+  graph_nodes_key_.clear();
   node_map_.clear();
   graph_ = CudaGraph(device_);
+  is_graph_updatable_ = true;
 }
 
 // Put completion handlers in a batch.
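(Aside, not part of the patch.) The reworked commit() above keys the executable-graph cache on the graph's structure: graph_nodes_key_ collects the node types in insertion order, graph_deps_key_ collects the id pairs of every dependency edge, and the lookup key is nodes_key + ":" + deps_key. On a cache hit it tries cudaGraphExecUpdate in place and re-instantiates only if the update is rejected; when is_graph_updatable_ is false it skips the cache and launches a throwaway exec. Below is a minimal standalone sketch of that update-or-reinstantiate pattern, assuming CUDA 12+ for the cudaGraphExecUpdateResultInfo signature; the names exec_cache and launch_cached are hypothetical and error handling is trimmed.

// Sketch only, not MLX code. MLX's CudaGraphExec RAII wrapper plays the role
// of the raw cudaGraphExec_t handle used here.
#include <cuda_runtime.h>

#include <string>
#include <unordered_map>

// Hypothetical cache mapping a structure key to an instantiated exec.
std::unordered_map<std::string, cudaGraphExec_t> exec_cache;

void launch_cached(const std::string& key, cudaGraph_t graph, cudaStream_t stream) {
  cudaGraphExec_t& exec = exec_cache[key]; // nullptr on first use
  if (exec != nullptr) {
    // Same structure seen before: try to update the executable graph in place.
    cudaGraphExecUpdateResultInfo info;
    cudaGraphExecUpdate(exec, graph, &info);
    if (info.result != cudaGraphExecUpdateSuccess) {
      cudaGetLastError(); // clear the error left by the failed update
      cudaGraphExecDestroy(exec);
      exec = nullptr;
    }
  }
  if (exec == nullptr) {
    // First use, or the in-place update was rejected: instantiate from scratch.
    cudaGraphInstantiate(&exec, graph, 0);
  }
  cudaGraphLaunch(exec, stream);
}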
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index 196cd799f..e7b14c555 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -106,8 +106,9 @@ class CommandEncoder {
     cudaGraphNode_t node;
     // K = kernel
     // E = empty
-    // G = subgraph
-    char node_type;
+    // G* = subgraph (with metadata)
+    // Symbols ':', '-' are reserved as separators
+    std::string node_type;
     std::string id;
   };
 
@@ -119,12 +120,11 @@ class CommandEncoder {
   CudaGraph graph_;
   Worker worker_;
   char node_count_{0};
-  char graph_node_count_{0};
-  char empty_node_count_{0};
   bool in_concurrent_{false};
   std::vector<cudaGraphNode_t> from_nodes_;
   std::vector<cudaGraphNode_t> to_nodes_;
-  std::string graph_key_;
+  std::string graph_nodes_key_;
+  std::string graph_deps_key_;
   std::vector<GraphNode> concurrent_nodes_;
   std::vector<std::shared_ptr<array::Data>> temporaries_;
   LRUCache<std::string, CudaGraphExec> graph_cache_;
@@ -132,6 +132,7 @@ class CommandEncoder {
   std::vector<std::uintptr_t> active_outputs_;
   std::unordered_map<std::uintptr_t, GraphNode> node_map_;
   size_t bytes_in_graph_{0};
+  bool is_graph_updatable_{true};
   int max_ops_per_graph_;
   int max_mb_per_graph_;
 };
diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp
index 3ab1a05bc..eb74ae14a 100644
--- a/mlx/backend/cuda/utils.cpp
+++ b/mlx/backend/cuda/utils.cpp
@@ -5,6 +5,7 @@
 #include "mlx/dtype_utils.h"
 
 #include
+#include
 
 namespace mlx::core {