From ae9dbb1a9b2be72da419bc43ecba10b1e68b4a0a Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 17 Jul 2025 23:48:37 -0700 Subject: [PATCH] Fix recording cudnn conv --- mlx/backend/cuda/conv.cpp | 54 ++++++++++++++++++++++++------------- mlx/backend/cuda/device.cpp | 51 ++++++++++++++++++++++------------- mlx/backend/cuda/device.h | 5 ++++ tests/ops_tests.cpp | 3 +++ 4 files changed, 76 insertions(+), 37 deletions(-) diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 3c5d4e532..c2043b802 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -20,6 +20,10 @@ namespace cu { using namespace cudnn_frontend; +// The populate_cuda_graph API is supposed to integrate cuDNN with CUDA graph +// but it does not seems to support convolution yet. +#define CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API 0 + #define CHECK_CUDNN_FE_ERROR(cmd) \ if (cmd.is_bad()) { \ throw std::runtime_error(fmt::format("{} failed.", #cmd)); \ @@ -29,7 +33,7 @@ auto swapaxes(const array& in, int axis1, int axis2) { std::vector axes(in.ndim()); std::iota(axes.begin(), axes.end(), 0); std::swap(axes[axis1], axes[axis2]); - std::vector shape(axes.size()); + std::vector shape(in.ndim()); std::vector strides(in.ndim()); for (size_t ax = 0; ax < axes.size(); ++ax) { shape[ax] = in.shape()[axes[ax]]; @@ -59,13 +63,18 @@ class Convolution { bool is_half = dtype == float16 || dtype == bfloat16; graph_.set_io_data_type(cudnn_type) + .set_intermediate_data_type(cudnn_type) .set_compute_data_type(is_half ? DataType_t::FLOAT : cudnn_type); +#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API + graph_.select_behavior_notes( + {BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); +#endif input_attr_ = graph_.tensor(graph::Tensor_attributes() - .set_name("input") + .set_data_type(cudnn_type) .set_dim(input_shape) .set_stride(input_strides)); filter_attr_ = graph_.tensor(graph::Tensor_attributes() - .set_name("filter") + .set_data_type(cudnn_type) .set_dim(filter_shape) .set_stride(filter_strides)); @@ -75,10 +84,10 @@ class Convolution { .set_stride(stride) .set_dilation(dilation); output_attr_ = graph_.conv_fprop(input_attr_, filter_attr_, conv_options); - output_attr_->set_output(true); - output_attr_->set_data_type(cudnn_type); - output_attr_->set_dim(output_shape); - output_attr_->set_stride(output_strides); + output_attr_->set_output(true) + .set_data_type(cudnn_type) + .set_dim(output_shape) + .set_stride(output_strides); CHECK_CUDNN_FE_ERROR(graph_.validate()); CHECK_CUDNN_FE_ERROR(graph_.build_operation_graph(handle_)); @@ -93,23 +102,29 @@ class Convolution { const void* input, const void* filter, void* output) { - float alpha = 1; - float beta = 0; - array workspace( allocator::malloc(workspace_size_), {static_cast(workspace_size_)}, int8); encoder.add_temporary(workspace); - std::unordered_map ptr_map{ + std::unordered_map variant_pack{ {input_attr_->get_uid(), const_cast(input)}, {filter_attr_->get_uid(), const_cast(filter)}, {output_attr_->get_uid(), output}}; +#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API + cudaGraph_t cudnn_cuda_graph; + cudaGraphCreate(&cudnn_cuda_graph, 0); + CHECK_CUDNN_FE_ERROR(graph_.populate_cuda_graph( + handle_, variant_pack, workspace.data(), cudnn_cuda_graph)); + encoder.add_graph_node(cudnn_cuda_graph); + cudaGraphDestroy(cudnn_cuda_graph); +#else auto capture = encoder.capture_context(); CHECK_CUDNN_FE_ERROR( - graph_.execute(handle_, ptr_map, workspace.data())); + graph_.execute(handle_, variant_pack, workspace.data())); +#endif } private: @@ -158,18 +173,19 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); assert(inputs.size() == 2); - const auto& in = inputs[0]; - const auto& wt = inputs[1]; + const array& input = inputs[0]; + const array& filter = inputs[1]; out.set_data(allocator::malloc(out.nbytes())); - int ndim = in.ndim(); - auto [input_shape, input_strides] = cu::swapaxes(in, 1, ndim - 1); - auto [filter_shape, filter_strides] = cu::swapaxes(wt, 1, ndim - 1); + // cuDNN requires dims to be passed as NCHW. + int ndim = input.ndim(); + auto [input_shape, input_strides] = cu::swapaxes(input, 1, ndim - 1); + auto [filter_shape, filter_strides] = cu::swapaxes(filter, 1, ndim - 1); auto [output_shape, output_strides] = cu::swapaxes(out, 1, ndim - 1); cu::Convolution conv( cu::device(s.device), - in.dtype(), + input.dtype(), input_shape, input_strides, filter_shape, @@ -181,7 +197,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { std::vector(padding_hi_.begin(), padding_hi_.end()), std::vector(kernel_dilation_.begin(), kernel_dilation_.end()), groups_); - conv.run(encoder, in.data(), wt.data(), out.data()); + conv.run(encoder, input.data(), filter.data(), out.data()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index ff0ff695c..a9066cdd4 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -86,23 +86,26 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CommandEncoder::CaptureContext::~CaptureContext() { CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph)); + std::unique_ptr graph_freer( + &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); }); + + // Extract and add as single kernel node when possible. size_t num_nodes; CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes)); if (num_nodes == 1) { cudaGraphNode_t captured_node; CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes)); - CUDA_KERNEL_NODE_PARAMS params; - CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms)); - cudaGraphNode_t node; - CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, ¶ms)); - enc.insert_graph_dependencies(GraphNode{node, 'K'}); - } else { - cudaGraphNode_t node; - CHECK_CUDA_ERROR( - cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph)); - enc.insert_graph_dependencies(GraphNode{node, 'G'}); + cudaGraphNodeType type; + CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type)); + if (type == cudaGraphNodeTypeKernel) { + CUDA_KERNEL_NODE_PARAMS params; + CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms)); + enc.add_kernel_node(params); + return; + } } - CHECK_CUDA_ERROR(cudaGraphDestroy(graph)); + // Otherwise add the captured graph as subgraph. + enc.add_graph_node(graph); } CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc) @@ -236,10 +239,7 @@ void CommandEncoder::add_kernel_node( kernel_params.gridDim = grid_dim; kernel_params.blockDim = block_dim; kernel_params.kernelParams = params; - cudaGraphNode_t node; - CHECK_CUDA_ERROR( - cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); - insert_graph_dependencies(GraphNode{node, 'K'}); + add_kernel_node(kernel_params); } void CommandEncoder::add_kernel_node( @@ -256,12 +256,27 @@ void CommandEncoder::add_kernel_node( kernel_params.blockDimY = block_dim.y; kernel_params.blockDimZ = block_dim.z; kernel_params.kernelParams = params; - CUgraphNode node; - CHECK_CUDA_ERROR( - cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); + add_kernel_node(kernel_params); +} + +void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) { + cudaGraphNode_t node; + CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms)); insert_graph_dependencies(GraphNode{node, 'K'}); } +void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) { + CUgraphNode node; + CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms)); + insert_graph_dependencies(GraphNode{node, 'K'}); +} + +void CommandEncoder::add_graph_node(cudaGraph_t child) { + cudaGraphNode_t node; + CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); + insert_graph_dependencies(GraphNode{node, 'G'}); +} + void CommandEncoder::commit() { if (!temporaries_.empty()) { add_completed_handler([temporaries = std::move(temporaries_)]() {}); diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index a937ede83..e3456f7ed 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -66,6 +66,11 @@ class CommandEncoder { void add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params); + // Low-level graph helpers. + void add_kernel_node(const cudaKernelNodeParams& params); + void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params); + void add_graph_node(cudaGraph_t child); + void add_temporary(const array& arr) { temporaries_.push_back(arr.data_shared_ptr()); } diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 969bc2ba7..f4c34a279 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -4,6 +4,7 @@ #define _USE_MATH_DEFINES #include +#include #include #include "doctest/doctest.h" @@ -3642,6 +3643,7 @@ TEST_CASE("test conv1d") { {1, 3, 4}); auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups); + std::cout << out << std::endl; CHECK(allclose(out, expected).item()); } @@ -3720,6 +3722,7 @@ TEST_CASE("test conv2d") { auto expected = array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4}); auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups); + std::cout << out << std::endl; CHECK(allclose(out, expected).item()); }