Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-01 04:24:36 +08:00)
Fix recording cudnn conv
@@ -20,6 +20,10 @@ namespace cu {
 using namespace cudnn_frontend;

+// The populate_cuda_graph API is supposed to integrate cuDNN with CUDA graphs,
+// but it does not seem to support convolution yet.
+#define CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API 0
+
 #define CHECK_CUDNN_FE_ERROR(cmd)                                \
   if (cmd.is_bad()) {                                            \
     throw std::runtime_error(fmt::format("{} failed.", #cmd));   \
@@ -29,7 +33,7 @@ auto swapaxes(const array& in, int axis1, int axis2) {
   std::vector<int> axes(in.ndim());
   std::iota(axes.begin(), axes.end(), 0);
   std::swap(axes[axis1], axes[axis2]);
-  std::vector<int64_t> shape(axes.size());
+  std::vector<int64_t> shape(in.ndim());
   std::vector<int64_t> strides(in.ndim());
   for (size_t ax = 0; ax < axes.size(); ++ax) {
     shape[ax] = in.shape()[axes[ax]];
@@ -59,13 +63,18 @@ class Convolution {
     bool is_half = dtype == float16 || dtype == bfloat16;

     graph_.set_io_data_type(cudnn_type)
         .set_intermediate_data_type(cudnn_type)
         .set_compute_data_type(is_half ? DataType_t::FLOAT : cudnn_type);
+#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
+    graph_.select_behavior_notes(
+        {BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
+#endif
     input_attr_ = graph_.tensor(graph::Tensor_attributes()
                                     .set_name("input")
                                     .set_data_type(cudnn_type)
                                     .set_dim(input_shape)
                                     .set_stride(input_strides));
     filter_attr_ = graph_.tensor(graph::Tensor_attributes()
                                      .set_name("filter")
                                      .set_data_type(cudnn_type)
                                      .set_dim(filter_shape)
                                      .set_stride(filter_strides));

@@ -75,10 +84,10 @@ class Convolution {
         .set_stride(stride)
         .set_dilation(dilation);
     output_attr_ = graph_.conv_fprop(input_attr_, filter_attr_, conv_options);
-    output_attr_->set_output(true);
-    output_attr_->set_data_type(cudnn_type);
-    output_attr_->set_dim(output_shape);
-    output_attr_->set_stride(output_strides);
+    output_attr_->set_output(true)
+        .set_data_type(cudnn_type)
+        .set_dim(output_shape)
+        .set_stride(output_strides);

     CHECK_CUDNN_FE_ERROR(graph_.validate());
     CHECK_CUDNN_FE_ERROR(graph_.build_operation_graph(handle_));
@@ -93,23 +102,29 @@ class Convolution {
       const void* input,
       const void* filter,
       void* output) {
-    float alpha = 1;
-    float beta = 0;
-
     array workspace(
         allocator::malloc(workspace_size_),
         {static_cast<int>(workspace_size_)},
         int8);
     encoder.add_temporary(workspace);

-    std::unordered_map<int64_t, void*> ptr_map{
+    std::unordered_map<int64_t, void*> variant_pack{
        {input_attr_->get_uid(), const_cast<void*>(input)},
        {filter_attr_->get_uid(), const_cast<void*>(filter)},
        {output_attr_->get_uid(), output}};

+#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
+    cudaGraph_t cudnn_cuda_graph;
+    cudaGraphCreate(&cudnn_cuda_graph, 0);
+    CHECK_CUDNN_FE_ERROR(graph_.populate_cuda_graph(
+        handle_, variant_pack, workspace.data<void>(), cudnn_cuda_graph));
+    encoder.add_graph_node(cudnn_cuda_graph);
+    cudaGraphDestroy(cudnn_cuda_graph);
+#else
     auto capture = encoder.capture_context();
     CHECK_CUDNN_FE_ERROR(
-        graph_.execute(handle_, ptr_map, workspace.data<void>()));
+        graph_.execute(handle_, variant_pack, workspace.data<void>()));
+#endif
   }

  private:
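The run() method above has two paths: the #if branch would let cuDNN's native populate_cuda_graph API emit the work into a CUDA graph directly, but per the comment at the top of the file that does not yet support convolution, so the default branch stream-captures graph_.execute() through encoder.capture_context(). For readers unfamiliar with stream capture, here is a minimal standalone sketch in plain CUDA, not MLX's CommandEncoder; capture_into and launch_work are hypothetical names used only for illustration:

#include <cuda_runtime.h>
#include <functional>

// Record arbitrary stream-ordered work (e.g. a cuDNN or cuBLAS call) into a
// standalone CUDA graph, then splice it into a larger graph as a child node.
void capture_into(cudaGraph_t parent, cudaStream_t stream,
                  const std::function<void(cudaStream_t)>& launch_work) {
  cudaGraph_t captured;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  launch_work(stream);                      // work is recorded, not executed
  cudaStreamEndCapture(stream, &captured);  // yields a standalone cudaGraph_t

  // Embed the captured work in the parent graph as a single child node; the
  // parent keeps its own copy, so the captured graph can be destroyed here.
  cudaGraphNode_t node;
  cudaGraphAddChildGraphNode(&node, parent, nullptr, 0, captured);
  cudaGraphDestroy(captured);
}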
@@ -158,18 +173,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& encoder = cu::get_command_encoder(s);

   assert(inputs.size() == 2);
-  const auto& in = inputs[0];
-  const auto& wt = inputs[1];
+  const array& input = inputs[0];
+  const array& filter = inputs[1];
   out.set_data(allocator::malloc(out.nbytes()));

-  int ndim = in.ndim();
-  auto [input_shape, input_strides] = cu::swapaxes(in, 1, ndim - 1);
-  auto [filter_shape, filter_strides] = cu::swapaxes(wt, 1, ndim - 1);
+  // cuDNN requires dims to be passed as NCHW.
+  int ndim = input.ndim();
+  auto [input_shape, input_strides] = cu::swapaxes(input, 1, ndim - 1);
+  auto [filter_shape, filter_strides] = cu::swapaxes(filter, 1, ndim - 1);
   auto [output_shape, output_strides] = cu::swapaxes(out, 1, ndim - 1);

   cu::Convolution conv(
       cu::device(s.device),
-      in.dtype(),
+      input.dtype(),
       input_shape,
       input_strides,
       filter_shape,
@@ -181,7 +197,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
       std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
       std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
       groups_);
-  conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
+  conv.run(encoder, input.data<void>(), filter.data<void>(), out.data<void>());
 }

 } // namespace mlx::core
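The cu::swapaxes calls in eval_gpu do not move any data: they only permute the dims and strides so that the channels-last arrays can be described to cuDNN with the channel axis next to the batch axis. A small standalone illustration of that permutation (not MLX code; the shapes here are made up):

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // A contiguous channels-last (NHWC) tensor of shape {1, 8, 8, 3}.
  std::array<int64_t, 4> shape{1, 8, 8, 3};
  std::array<int64_t, 4> strides{8 * 8 * 3, 8 * 3, 3, 1};  // row-major

  // swapaxes(in, 1, ndim - 1): exchange axis 1 and the last axis in both the
  // shape and the strides. The memory is untouched; only its description
  // changes, so cuDNN sees a strided channels-second view of the same buffer.
  std::swap(shape[1], shape[3]);
  std::swap(strides[1], strides[3]);

  // Prints dims {1, 3, 8, 8} with strides {192, 1, 3, 24}.
  for (int i = 0; i < 4; ++i)
    std::printf("dim %d: size %lld, stride %lld\n", i,
                static_cast<long long>(shape[i]),
                static_cast<long long>(strides[i]));
  return 0;
}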
@@ -86,23 +86,26 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {

 CommandEncoder::CaptureContext::~CaptureContext() {
   CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
+  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
+      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+
+  // Extract and add as single kernel node when possible.
   size_t num_nodes;
   CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
   if (num_nodes == 1) {
     cudaGraphNode_t captured_node;
     CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
-    CUDA_KERNEL_NODE_PARAMS params;
-    CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
-    cudaGraphNode_t node;
-    CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, &params));
-    enc.insert_graph_dependencies(GraphNode{node, 'K'});
-  } else {
-    cudaGraphNode_t node;
-    CHECK_CUDA_ERROR(
-        cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
-    enc.insert_graph_dependencies(GraphNode{node, 'G'});
+    cudaGraphNodeType type;
+    CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
+    if (type == cudaGraphNodeTypeKernel) {
+      CUDA_KERNEL_NODE_PARAMS params;
+      CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
+      enc.add_kernel_node(params);
+      return;
+    }
   }
-  CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
+  // Otherwise add the captured graph as subgraph.
+  enc.add_graph_node(graph);
 }

 CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
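The reworked destructor now checks the node type before treating a one-node capture as a kernel launch, re-adding it as a plain kernel node in the encoder's graph so it is tracked like any other kernel the encoder records, and otherwise embeds the whole capture as a child graph. The same idea, stripped of the encoder plumbing, looks roughly like this sketch (add_captured is a hypothetical helper; it mixes the runtime and driver graph APIs the same way the file above does):

#include <cuda.h>
#include <cuda_runtime.h>

// Fold a freshly captured graph into 'target': promote a lone kernel node to
// a first-class kernel node, otherwise embed the capture as a child graph.
void add_captured(cudaGraph_t target, cudaGraph_t captured) {
  size_t num_nodes = 0;
  cudaGraphGetNodes(captured, nullptr, &num_nodes);
  if (num_nodes == 1) {
    cudaGraphNode_t only;
    cudaGraphGetNodes(captured, &only, &num_nodes);
    cudaGraphNodeType type;
    cudaGraphNodeGetType(only, &type);
    if (type == cudaGraphNodeTypeKernel) {
      CUDA_KERNEL_NODE_PARAMS params;
      cuGraphKernelNodeGetParams(only, &params);
      CUgraphNode node;
      cuGraphAddKernelNode(&node, target, nullptr, 0, &params);
      return;
    }
  }
  cudaGraphNode_t node;
  cudaGraphAddChildGraphNode(&node, target, nullptr, 0, captured);
}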
@@ -236,10 +239,7 @@ void CommandEncoder::add_kernel_node(
   kernel_params.gridDim = grid_dim;
   kernel_params.blockDim = block_dim;
   kernel_params.kernelParams = params;
-  cudaGraphNode_t node;
-  CHECK_CUDA_ERROR(
-      cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  add_kernel_node(kernel_params);
 }

 void CommandEncoder::add_kernel_node(
@@ -256,12 +256,27 @@ void CommandEncoder::add_kernel_node(
   kernel_params.blockDimY = block_dim.y;
   kernel_params.blockDimZ = block_dim.z;
   kernel_params.kernelParams = params;
-  CUgraphNode node;
-  CHECK_CUDA_ERROR(
-      cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
+  add_kernel_node(kernel_params);
 }

+void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
+  insert_graph_dependencies(GraphNode{node, 'K'});
+}
+
+void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
+  CUgraphNode node;
+  CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
+  insert_graph_dependencies(GraphNode{node, 'K'});
+}
+
+void CommandEncoder::add_graph_node(cudaGraph_t child) {
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
+  insert_graph_dependencies(GraphNode{node, 'G'});
+}
+
 void CommandEncoder::commit() {
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
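The new add_kernel_node(const cudaKernelNodeParams&) overload is a thin wrapper around cudaGraphAddKernelNode plus the encoder's dependency tracking. As a reminder of what goes into those params, a minimal standalone example (not MLX code; axpy and add_axpy_node are made-up names for illustration):

#include <cuda_runtime.h>

__global__ void axpy(float a, const float* x, float* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = a * x[i] + y[i];
}

// Describe one launch of axpy as a kernel node and append it to 'graph'.
cudaGraphNode_t add_axpy_node(cudaGraph_t graph, float a, const float* x,
                              float* y, int n) {
  void* args[] = {&a, &x, &y, &n};  // argument values are copied at node creation
  cudaKernelNodeParams params = {};
  params.func = reinterpret_cast<void*>(axpy);
  params.gridDim = dim3((n + 255) / 256);
  params.blockDim = dim3(256);
  params.sharedMemBytes = 0;
  params.kernelParams = args;
  cudaGraphNode_t node;
  cudaGraphAddKernelNode(&node, graph, /*dependencies=*/nullptr, 0, &params);
  return node;
}

In the encoder the resulting node is additionally registered through insert_graph_dependencies, so it is ordered relative to the encoder's other recorded nodes.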
@@ -66,6 +66,11 @@ class CommandEncoder {
   void
   add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);

+  // Low-level graph helpers.
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
+  void add_graph_node(cudaGraph_t child);
+
   void add_temporary(const array& arr) {
     temporaries_.push_back(arr.data_shared_ptr());
   }
@@ -4,6 +4,7 @@
 #define _USE_MATH_DEFINES

 #include <cmath>
+#include <iostream>
 #include <numeric>

 #include "doctest/doctest.h"
@@ -3642,6 +3643,7 @@ TEST_CASE("test conv1d") {
       {1, 3, 4});

   auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
+  std::cout << out << std::endl;
   CHECK(allclose(out, expected).item<bool>());
 }

@@ -3720,6 +3722,7 @@ TEST_CASE("test conv2d") {
   auto expected =
       array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
   auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
+  std::cout << out << std::endl;
   CHECK(allclose(out, expected).item<bool>());
 }