mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Enable more graphs to be updatable (#2883)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
This commit is contained in:
@@ -318,46 +318,52 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
|||||||
insert_graph_dependencies(GraphNode{node, "K"});
|
insert_graph_dependencies(GraphNode{node, "K"});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
|
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
|
||||||
// CUDA graphs do not get updated correctly if a kernel node getting updated
|
// Constructs a key representing the nodes of a sub-graph.
|
||||||
// has a different cluster shape than the node it's being updated with.
|
// Also checks if the sub-graph is updatable as CUDA graphs do not get
|
||||||
|
// updated correctly if a kernel node getting updated has a different cluster
|
||||||
|
// shape than the node it's being updated with.
|
||||||
|
std::string key = "(";
|
||||||
size_t num_nodes = 0;
|
size_t num_nodes = 0;
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
|
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
|
||||||
if (num_nodes == 0) {
|
if (num_nodes == 0) {
|
||||||
return true;
|
return {key + ")", true};
|
||||||
}
|
}
|
||||||
|
bool is_updatable = true;
|
||||||
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
|
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
|
||||||
for (const auto& node : nodes) {
|
for (const auto& node : nodes) {
|
||||||
|
if (!is_updatable) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
cudaGraphNodeType type;
|
cudaGraphNodeType type;
|
||||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||||
if (type == cudaGraphNodeTypeGraph) {
|
if (type == cudaGraphNodeTypeGraph) {
|
||||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||||
if (num_nodes > 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
cudaGraph_t child;
|
cudaGraph_t child;
|
||||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||||
return is_graph_updatable(child, cluster_dim_x);
|
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
||||||
|
is_updatable &= sub_is_updatable;
|
||||||
|
key += subkey;
|
||||||
|
} else if (type == cudaGraphNodeTypeMemset) {
|
||||||
|
key += "M";
|
||||||
} else if (type != cudaGraphNodeTypeKernel) {
|
} else if (type != cudaGraphNodeTypeKernel) {
|
||||||
return false;
|
is_updatable = false;
|
||||||
} else {
|
} else {
|
||||||
cudaLaunchAttributeValue cluster_dim;
|
cudaLaunchAttributeValue cluster_dim;
|
||||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||||
// Only dim.x can be greater than 1
|
// Only allow dim.x to be greater than 1
|
||||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||||
return false;
|
is_updatable = false;
|
||||||
}
|
} else {
|
||||||
// Only one child node allowed when subgraph uses clusters
|
key += "K";
|
||||||
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
|
key += std::to_string(cluster_dim.clusterDim.x);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
cluster_dim_x = cluster_dim.clusterDim.x;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
}
|
||||||
|
key += ")";
|
||||||
|
return {key, is_updatable};
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||||
@@ -370,11 +376,10 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
int cluster_dim_x = 0;
|
auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
|
||||||
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
|
is_graph_updatable_ &= is_updatable;
|
||||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||||
insert_graph_dependencies(
|
insert_graph_dependencies(GraphNode{node, sub_graph_key});
|
||||||
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CommandEncoder::needs_commit() {
|
bool CommandEncoder::needs_commit() {
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ class CommandEncoder {
|
|||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
// K = kernel
|
// K = kernel
|
||||||
// E = empty
|
// E = empty
|
||||||
// G* = subgraph (with metadata)
|
// () = subgraph (with metadata)
|
||||||
// Symbols ':', '-' are reserved as separators
|
// Symbols ':', '-' are reserved as separators
|
||||||
std::string node_type;
|
std::string node_type;
|
||||||
std::string id;
|
std::string id;
|
||||||
|
|||||||
@@ -80,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CudaGraph::end_capture(cudaStream_t stream) {
|
void CudaGraph::end_capture(cudaStream_t stream) {
|
||||||
assert(handle_ == nullptr);
|
|
||||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
|
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user