Fix recording cudnn conv

Author: Cheng
Date: 2025-07-17 23:48:37 -07:00
Parent: 6571df6ad7
Commit: ae9dbb1a9b
4 changed files with 76 additions and 37 deletions

View File

@@ -20,6 +20,10 @@ namespace cu {
using namespace cudnn_frontend;
// The populate_cuda_graph API is supposed to integrate cuDNN with CUDA graphs
// but it does not seem to support convolution yet.
#define CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API 0
#define CHECK_CUDNN_FE_ERROR(cmd) \
if (cmd.is_bad()) { \
throw std::runtime_error(fmt::format("{} failed.", #cmd)); \
@@ -29,7 +33,7 @@ auto swapaxes(const array& in, int axis1, int axis2) {
std::vector<int> axes(in.ndim());
std::iota(axes.begin(), axes.end(), 0);
std::swap(axes[axis1], axes[axis2]);
std::vector<int64_t> shape(axes.size());
std::vector<int64_t> shape(in.ndim());
std::vector<int64_t> strides(in.ndim());
for (size_t ax = 0; ax < axes.size(); ++ax) {
shape[ax] = in.shape()[axes[ax]];
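Editor's note: for context, a minimal standalone sketch of what this helper computes (illustrative names, not part of the diff). It builds an identity permutation, swaps two axes, then gathers shape and strides through the permutation, yielding a transposed view of the same buffer with no data movement.

#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Illustrative sketch of the swapaxes helper on raw shape/stride vectors.
std::pair<std::vector<int64_t>, std::vector<int64_t>> swapaxes_sketch(
    const std::vector<int64_t>& shape,
    const std::vector<int64_t>& strides,
    int axis1,
    int axis2) {
  std::vector<int> axes(shape.size());
  std::iota(axes.begin(), axes.end(), 0);
  std::swap(axes[axis1], axes[axis2]);
  std::vector<int64_t> new_shape(shape.size());
  std::vector<int64_t> new_strides(strides.size());
  for (size_t ax = 0; ax < axes.size(); ++ax) {
    new_shape[ax] = shape[axes[ax]];
    new_strides[ax] = strides[axes[ax]];
  }
  return {new_shape, new_strides};
}
// E.g. an NHWC array of shape (2, 4, 8, 3) with row-major strides
// (96, 24, 3, 1): swapping axes 1 and 3 gives dims (2, 3, 8, 4) and
// strides (96, 1, 3, 24) -- a channels-second view of the same data.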
@@ -59,13 +63,18 @@ class Convolution {
bool is_half = dtype == float16 || dtype == bfloat16;
graph_.set_io_data_type(cudnn_type)
.set_intermediate_data_type(cudnn_type)
.set_compute_data_type(is_half ? DataType_t::FLOAT : cudnn_type);
#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
graph_.select_behavior_notes(
{BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
#endif
input_attr_ = graph_.tensor(graph::Tensor_attributes()
.set_name("input")
.set_data_type(cudnn_type)
.set_dim(input_shape)
.set_stride(input_strides));
filter_attr_ = graph_.tensor(graph::Tensor_attributes()
.set_name("filter")
.set_data_type(cudnn_type)
.set_dim(filter_shape)
.set_stride(filter_strides));
@@ -75,10 +84,10 @@ class Convolution {
.set_stride(stride)
.set_dilation(dilation);
output_attr_ = graph_.conv_fprop(input_attr_, filter_attr_, conv_options);
output_attr_->set_output(true);
output_attr_->set_data_type(cudnn_type);
output_attr_->set_dim(output_shape);
output_attr_->set_stride(output_strides);
output_attr_->set_output(true)
.set_data_type(cudnn_type)
.set_dim(output_shape)
.set_stride(output_strides);
CHECK_CUDNN_FE_ERROR(graph_.validate());
CHECK_CUDNN_FE_ERROR(graph_.build_operation_graph(handle_));
@@ -93,23 +102,29 @@ class Convolution {
const void* input,
const void* filter,
void* output) {
float alpha = 1;
float beta = 0;
array workspace(
allocator::malloc(workspace_size_),
{static_cast<int>(workspace_size_)},
int8);
encoder.add_temporary(workspace);
std::unordered_map<int64_t, void*> ptr_map{
std::unordered_map<int64_t, void*> variant_pack{
{input_attr_->get_uid(), const_cast<void*>(input)},
{filter_attr_->get_uid(), const_cast<void*>(filter)},
{output_attr_->get_uid(), output}};
#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
cudaGraph_t cudnn_cuda_graph;
cudaGraphCreate(&cudnn_cuda_graph, 0);
CHECK_CUDNN_FE_ERROR(graph_.populate_cuda_graph(
handle_, variant_pack, workspace.data<void>(), cudnn_cuda_graph));
encoder.add_graph_node(cudnn_cuda_graph);
cudaGraphDestroy(cudnn_cuda_graph);
#else
auto capture = encoder.capture_context();
CHECK_CUDNN_FE_ERROR(
graph_.execute(handle_, ptr_map, workspace.data<void>()));
graph_.execute(handle_, variant_pack, workspace.data<void>()));
#endif
}
private:
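Editor's note: since CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API is 0, the path taken above is stream capture. capture_context() is an RAII guard (implemented in the encoder changes later in this commit); it presumably begins capture on the encoder's stream in its constructor, so graph_.execute() is recorded rather than run, and the captured work is grafted into the encoder's CUDA graph when the guard is destroyed. A conceptual sketch of that call site, with the capture steps spelled out as comments (illustrative only, not the exact encoder internals):

{
  auto capture = encoder.capture_context();
  // ~ cudaStreamBeginCapture(encoder.stream(), ...): subsequent launches on
  //   the stream are recorded instead of executed.
  CHECK_CUDNN_FE_ERROR(
      graph_.execute(handle_, variant_pack, workspace.data<void>()));
}
// ~CaptureContext: cudaStreamEndCapture(...), then the captured nodes are
// added to the encoder's cudaGraph_t (a lone kernel node when possible,
// otherwise a child graph).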
@@ -158,18 +173,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2);
const auto& in = inputs[0];
const auto& wt = inputs[1];
const array& input = inputs[0];
const array& filter = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
int ndim = in.ndim();
auto [input_shape, input_strides] = cu::swapaxes(in, 1, ndim - 1);
auto [filter_shape, filter_strides] = cu::swapaxes(wt, 1, ndim - 1);
// cuDNN requires dims to be passed as NCHW.
int ndim = input.ndim();
auto [input_shape, input_strides] = cu::swapaxes(input, 1, ndim - 1);
auto [filter_shape, filter_strides] = cu::swapaxes(filter, 1, ndim - 1);
auto [output_shape, output_strides] = cu::swapaxes(out, 1, ndim - 1);
cu::Convolution conv(
cu::device(s.device),
in.dtype(),
input.dtype(),
input_shape,
input_strides,
filter_shape,
@@ -181,7 +197,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
groups_);
conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
conv.run(encoder, input.data<void>(), filter.data<void>(), out.data<void>());
}
} // namespace mlx::core

View File

@@ -86,23 +86,26 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CommandEncoder::CaptureContext::~CaptureContext() {
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
// Extract and add as single kernel node when possible.
size_t num_nodes;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
if (num_nodes == 1) {
cudaGraphNode_t captured_node;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
CUDA_KERNEL_NODE_PARAMS params;
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, &params));
enc.insert_graph_dependencies(GraphNode{node, 'K'});
} else {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(
cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
enc.insert_graph_dependencies(GraphNode{node, 'G'});
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
if (type == cudaGraphNodeTypeKernel) {
CUDA_KERNEL_NODE_PARAMS params;
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
enc.add_kernel_node(params);
return;
}
}
CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
// Otherwise add the captured graph as subgraph.
enc.add_graph_node(graph);
}
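Editor's note: a self-contained sketch of the capture-and-graft pattern this destructor implements (hypothetical helper, not the encoder's API; error checking omitted for brevity). Work is captured from a stream, a lone kernel node is lifted directly into the parent graph, and anything else is embedded as a child graph.

#include <cuda.h>
#include <cuda_runtime.h>

template <typename F>
cudaGraphNode_t capture_into(cudaGraph_t parent, cudaStream_t stream, F&& work) {
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  work(stream);  // launches on `stream` are recorded, not executed
  cudaGraph_t captured;
  cudaStreamEndCapture(stream, &captured);
  size_t num_nodes = 0;
  cudaGraphGetNodes(captured, nullptr, &num_nodes);
  if (num_nodes == 1) {
    cudaGraphNode_t only;
    cudaGraphGetNodes(captured, &only, &num_nodes);
    cudaGraphNodeType type;
    cudaGraphNodeGetType(only, &type);
    if (type == cudaGraphNodeTypeKernel) {
      // A lone kernel node is lifted directly: cheaper than a child graph.
      CUDA_KERNEL_NODE_PARAMS params;
      cuGraphKernelNodeGetParams(only, &params);
      CUgraphNode lifted;  // same underlying type as cudaGraphNode_t
      cuGraphAddKernelNode(&lifted, parent, nullptr, 0, &params);
      cudaGraphDestroy(captured);
      return lifted;
    }
  }
  // Otherwise embed the capture as a child graph; the parent clones it,
  // so the local handle can be destroyed afterwards.
  cudaGraphNode_t node;
  cudaGraphAddChildGraphNode(&node, parent, nullptr, 0, captured);
  cudaGraphDestroy(captured);
  return node;
}

Usage would look like capture_into(graph, stream, [&](cudaStream_t s) { kernel<<<grid, block, 0, s>>>(args...); }); linking against the driver API (-lcuda) is needed for the cuGraph* calls.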
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
@@ -236,10 +239,7 @@ void CommandEncoder::add_kernel_node(
kernel_params.gridDim = grid_dim;
kernel_params.blockDim = block_dim;
kernel_params.kernelParams = params;
cudaGraphNode_t node;
CHECK_CUDA_ERROR(
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
insert_graph_dependencies(GraphNode{node, 'K'});
add_kernel_node(kernel_params);
}
void CommandEncoder::add_kernel_node(
@@ -256,12 +256,27 @@ void CommandEncoder::add_kernel_node(
kernel_params.blockDimY = block_dim.y;
kernel_params.blockDimZ = block_dim.z;
kernel_params.kernelParams = params;
CUgraphNode node;
CHECK_CUDA_ERROR(
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
add_kernel_node(kernel_params);
}
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
}
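Editor's note: a usage sketch for add_graph_node (hypothetical call site mirroring the cuDNN path earlier in this commit). cudaGraphAddChildGraphNode clones the child graph into the encoder's graph, which is why the caller may destroy its own handle immediately after adding it.

cudaGraph_t child;
cudaGraphCreate(&child, 0);
// ... populate `child`, e.g. via cuDNN's populate_cuda_graph ...
encoder.add_graph_node(child);
cudaGraphDestroy(child);  // safe: the encoder's graph holds its own clone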
void CommandEncoder::commit() {
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});

View File

@@ -66,6 +66,11 @@ class CommandEncoder {
void
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
// Low-level graph helpers.
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
void add_graph_node(cudaGraph_t child);
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
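Editor's note: a usage sketch for the new low-level overloads (hypothetical kernel and launch configuration; fields not set below stay at their zero defaults from the aggregate initializer).

cudaKernelNodeParams params = {};
params.func = reinterpret_cast<void*>(&my_kernel);  // assumed __global__ function
params.gridDim = dim3(grid);
params.blockDim = dim3(block);
params.kernelParams = kernel_args;  // array of pointers to the kernel arguments
encoder.add_kernel_node(params);    // adds a kernel node with dependency tracking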

View File

@@ -4,6 +4,7 @@
#define _USE_MATH_DEFINES
#include <cmath>
#include <iostream>
#include <numeric>
#include "doctest/doctest.h"
@@ -3642,6 +3643,7 @@ TEST_CASE("test conv1d") {
{1, 3, 4});
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>());
}
@@ -3720,6 +3722,7 @@ TEST_CASE("test conv2d") {
auto expected =
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>());
}