mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-11 15:06:42 +08:00
Force cudaGraphExec reinstantiation when clusters are used (#2813)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -115,18 +115,17 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
||||
}
|
||||
|
||||
// Use an empty graph node for synchronization
|
||||
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
||||
enc.empty_node_count_++;
|
||||
CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
|
||||
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
|
||||
|
||||
// Insert the concurrent -> empty node dependencies
|
||||
for (auto& from : enc.concurrent_nodes_) {
|
||||
enc.from_nodes_.push_back(from.node);
|
||||
enc.to_nodes_.push_back(empty.node);
|
||||
enc.graph_key_ += from.id;
|
||||
enc.graph_key_ += from.node_type;
|
||||
enc.graph_key_ += empty.id;
|
||||
enc.graph_key_ += empty.node_type;
|
||||
enc.graph_deps_key_ += from.id;
|
||||
enc.graph_deps_key_ += "-";
|
||||
enc.graph_deps_key_ += empty.id;
|
||||
enc.graph_deps_key_ += "-";
|
||||
}
|
||||
|
||||
// Insert the input -> concurrent node dependencies without updating output
|
||||
@@ -141,9 +140,6 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
||||
}
|
||||
|
||||
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
|
||||
if (node.node_type == 'G') {
|
||||
graph_node_count_++;
|
||||
}
|
||||
node.id = std::to_string(node_count_++);
|
||||
if (in_concurrent_) {
|
||||
concurrent_nodes_.push_back(std::move(node));
|
||||
@@ -155,6 +151,10 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
|
||||
}
|
||||
|
||||
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||
for (auto& node : nodes) {
|
||||
graph_nodes_key_ += node.node_type;
|
||||
graph_nodes_key_ += "-";
|
||||
}
|
||||
std::vector<GraphNode> deps;
|
||||
{
|
||||
// Dependencies must be added in the same order to produce a consistent
|
||||
@@ -182,10 +182,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||
for (auto& to : nodes) {
|
||||
from_nodes_.push_back(from.node);
|
||||
to_nodes_.push_back(to.node);
|
||||
graph_key_ += from.id;
|
||||
graph_key_ += from.node_type;
|
||||
graph_key_ += to.id;
|
||||
graph_key_ += to.node_type;
|
||||
graph_deps_key_ += from.id;
|
||||
graph_deps_key_ += "-";
|
||||
graph_deps_key_ += to.id;
|
||||
graph_deps_key_ += "-";
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -309,13 +309,46 @@ void CommandEncoder::add_kernel_node(
|
||||
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
insert_graph_dependencies(GraphNode{node, "K"});
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||
CUgraphNode node;
|
||||
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
insert_graph_dependencies(GraphNode{node, "K"});
|
||||
}
|
||||
|
||||
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
|
||||
// CUDA graphs do not get updated correctly if a kernel node getting updated
|
||||
// has a different cluster shape than the node it's being updated with.
|
||||
size_t num_nodes = 0;
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
|
||||
if (num_nodes == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
|
||||
for (const auto& node : nodes) {
|
||||
cudaGraphNodeType type;
|
||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||
if (type != cudaGraphNodeTypeKernel) {
|
||||
return false;
|
||||
}
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only dim.x can be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
return false;
|
||||
}
|
||||
// Only one child node allowed when subgraph uses clusters
|
||||
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
|
||||
return false;
|
||||
}
|
||||
cluster_dim_x = cluster_dim.clusterDim.x;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
@@ -328,8 +361,11 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
return;
|
||||
}
|
||||
cudaGraphNode_t node;
|
||||
int cluster_dim_x = 0;
|
||||
is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
|
||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||
insert_graph_dependencies(
|
||||
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
|
||||
}
|
||||
|
||||
bool CommandEncoder::needs_commit() {
|
||||
@@ -354,44 +390,45 @@ void CommandEncoder::commit() {
|
||||
from_nodes_.size()));
|
||||
}
|
||||
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(node_count_);
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(graph_node_count_);
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(empty_node_count_);
|
||||
|
||||
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
|
||||
|
||||
if (graph_exec != nullptr) {
|
||||
cudaGraphExecUpdateResult update_result;
|
||||
#if CUDART_VERSION >= 12000
|
||||
cudaGraphExecUpdateResultInfo info;
|
||||
cudaGraphExecUpdate(graph_exec, graph_, &info);
|
||||
update_result = info.result;
|
||||
#else
|
||||
cudaGraphNode_t error_node;
|
||||
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
if (update_result != cudaGraphExecUpdateSuccess) {
|
||||
cudaGetLastError(); // reset error
|
||||
graph_exec.reset();
|
||||
}
|
||||
}
|
||||
if (graph_exec == nullptr) {
|
||||
graph_exec.instantiate(graph_);
|
||||
}
|
||||
device_.make_current();
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
|
||||
if (!is_graph_updatable_) {
|
||||
CudaGraphExec graph_exec;
|
||||
graph_exec.instantiate(graph_);
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
} else {
|
||||
auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
|
||||
auto& graph_exec = graph_cache_[graph_key];
|
||||
|
||||
if (graph_exec != nullptr) {
|
||||
cudaGraphExecUpdateResult update_result;
|
||||
#if CUDART_VERSION >= 12000
|
||||
cudaGraphExecUpdateResultInfo info;
|
||||
cudaGraphExecUpdate(graph_exec, graph_, &info);
|
||||
update_result = info.result;
|
||||
#else
|
||||
cudaGraphNode_t error_node;
|
||||
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
if (update_result != cudaGraphExecUpdateSuccess) {
|
||||
cudaGetLastError(); // reset error
|
||||
graph_exec.reset();
|
||||
}
|
||||
}
|
||||
if (graph_exec == nullptr) {
|
||||
graph_exec.instantiate(graph_);
|
||||
}
|
||||
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
}
|
||||
// Reset state
|
||||
graph_node_count_ = 0;
|
||||
empty_node_count_ = 0;
|
||||
from_nodes_.clear();
|
||||
to_nodes_.clear();
|
||||
graph_key_.clear();
|
||||
graph_deps_key_.clear();
|
||||
graph_nodes_key_.clear();
|
||||
node_map_.clear();
|
||||
graph_ = CudaGraph(device_);
|
||||
is_graph_updatable_ = true;
|
||||
}
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
|
||||
@@ -106,8 +106,9 @@ class CommandEncoder {
|
||||
cudaGraphNode_t node;
|
||||
// K = kernel
|
||||
// E = empty
|
||||
// G = subgraph
|
||||
char node_type;
|
||||
// G* = subgraph (with metadata)
|
||||
// Symbols ':', '-' are reserved as separators
|
||||
std::string node_type;
|
||||
std::string id;
|
||||
};
|
||||
|
||||
@@ -119,12 +120,11 @@ class CommandEncoder {
|
||||
CudaGraph graph_;
|
||||
Worker worker_;
|
||||
char node_count_{0};
|
||||
char graph_node_count_{0};
|
||||
char empty_node_count_{0};
|
||||
bool in_concurrent_{false};
|
||||
std::vector<cudaGraphNode_t> from_nodes_;
|
||||
std::vector<cudaGraphNode_t> to_nodes_;
|
||||
std::string graph_key_;
|
||||
std::string graph_nodes_key_;
|
||||
std::string graph_deps_key_;
|
||||
std::vector<GraphNode> concurrent_nodes_;
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
LRUCache<std::string, CudaGraphExec> graph_cache_;
|
||||
@@ -132,6 +132,7 @@ class CommandEncoder {
|
||||
std::vector<std::uintptr_t> active_outputs_;
|
||||
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
||||
size_t bytes_in_graph_{0};
|
||||
bool is_graph_updatable_{true};
|
||||
int max_ops_per_graph_;
|
||||
int max_mb_per_graph_;
|
||||
};
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <vector>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user