mirror of https://github.com/ml-explore/mlx.git
Fix recording cudnn conv
@@ -20,6 +20,10 @@ namespace cu {
 
 using namespace cudnn_frontend;
 
+// The populate_cuda_graph API is supposed to integrate cuDNN with CUDA graph
+// but it does not seem to support convolution yet.
+#define CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API 0
+
 #define CHECK_CUDNN_FE_ERROR(cmd)                               \
   if (cmd.is_bad()) {                                           \
     throw std::runtime_error(fmt::format("{} failed.", #cmd));  \
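For reference, the macro above expands at a call site roughly as follows (a hypothetical expansion for illustration; `cmd` is evaluated once and `#cmd` stringifies the expression, assuming the cudnn_frontend call returns an error object whose `is_bad()` flags failure):

    // CHECK_CUDNN_FE_ERROR(graph_.validate()); becomes:
    if (graph_.validate().is_bad()) {
      throw std::runtime_error(fmt::format("{} failed.", "graph_.validate()"));
    }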
@@ -29,7 +33,7 @@ auto swapaxes(const array& in, int axis1, int axis2) {
   std::vector<int> axes(in.ndim());
   std::iota(axes.begin(), axes.end(), 0);
   std::swap(axes[axis1], axes[axis2]);
-  std::vector<int64_t> shape(axes.size());
+  std::vector<int64_t> shape(in.ndim());
   std::vector<int64_t> strides(in.ndim());
   for (size_t ax = 0; ax < axes.size(); ++ax) {
     shape[ax] = in.shape()[axes[ax]];
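The swapaxes helper builds a channels-first dim/stride description of an unmodified row-major buffer: no data moves, only the metadata is permuted. A small self-contained sketch of the same computation (standalone C++, not MLX code; the shape is made up for illustration):

    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
      // Row-major NHWC input, e.g. N=2, H=4, W=5, C=3.
      std::vector<int64_t> shape = {2, 4, 5, 3};
      std::vector<int64_t> strides(shape.size());
      int64_t acc = 1;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        strides[i] = acc;  // row-major strides: {60, 15, 3, 1}
        acc *= shape[i];
      }

      // Swap axis 1 and the last axis, as cu::swapaxes(in, 1, ndim - 1) does.
      std::vector<int> axes(shape.size());
      std::iota(axes.begin(), axes.end(), 0);
      std::swap(axes[1], axes[shape.size() - 1]);

      for (size_t ax = 0; ax < axes.size(); ++ax) {
        printf("dim %zu: size %lld stride %lld\n",
               ax,
               (long long)shape[axes[ax]],
               (long long)strides[axes[ax]]);
      }
      // Prints sizes {2, 3, 5, 4} with strides {60, 1, 3, 15}: the buffer is
      // described channels-first without copying.
      return 0;
    }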
@@ -59,13 +63,18 @@ class Convolution {
     bool is_half = dtype == float16 || dtype == bfloat16;
 
     graph_.set_io_data_type(cudnn_type)
+        .set_intermediate_data_type(cudnn_type)
         .set_compute_data_type(is_half ? DataType_t::FLOAT : cudnn_type);
+#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
+    graph_.select_behavior_notes(
+        {BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
+#endif
     input_attr_ = graph_.tensor(graph::Tensor_attributes()
-                                    .set_name("input")
+                                    .set_data_type(cudnn_type)
                                     .set_dim(input_shape)
                                     .set_stride(input_strides));
     filter_attr_ = graph_.tensor(graph::Tensor_attributes()
-                                     .set_name("filter")
+                                     .set_data_type(cudnn_type)
                                      .set_dim(filter_shape)
                                      .set_stride(filter_strides));
 
@@ -75,10 +84,10 @@ class Convolution {
         .set_stride(stride)
         .set_dilation(dilation);
     output_attr_ = graph_.conv_fprop(input_attr_, filter_attr_, conv_options);
-    output_attr_->set_output(true);
-    output_attr_->set_data_type(cudnn_type);
-    output_attr_->set_dim(output_shape);
-    output_attr_->set_stride(output_strides);
+    output_attr_->set_output(true)
+        .set_data_type(cudnn_type)
+        .set_dim(output_shape)
+        .set_stride(output_strides);
 
     CHECK_CUDNN_FE_ERROR(graph_.validate());
     CHECK_CUDNN_FE_ERROR(graph_.build_operation_graph(handle_));
@@ -93,23 +102,29 @@ class Convolution {
       const void* input,
       const void* filter,
       void* output) {
-    float alpha = 1;
-    float beta = 0;
-
     array workspace(
         allocator::malloc(workspace_size_),
         {static_cast<int>(workspace_size_)},
         int8);
     encoder.add_temporary(workspace);
 
-    std::unordered_map<int64_t, void*> ptr_map{
+    std::unordered_map<int64_t, void*> variant_pack{
         {input_attr_->get_uid(), const_cast<void*>(input)},
         {filter_attr_->get_uid(), const_cast<void*>(filter)},
         {output_attr_->get_uid(), output}};
 
+#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
+    cudaGraph_t cudnn_cuda_graph;
+    cudaGraphCreate(&cudnn_cuda_graph, 0);
+    CHECK_CUDNN_FE_ERROR(graph_.populate_cuda_graph(
+        handle_, variant_pack, workspace.data<void>(), cudnn_cuda_graph));
+    encoder.add_graph_node(cudnn_cuda_graph);
+    cudaGraphDestroy(cudnn_cuda_graph);
+#else
     auto capture = encoder.capture_context();
     CHECK_CUDNN_FE_ERROR(
-        graph_.execute(handle_, ptr_map, workspace.data<void>()));
+        graph_.execute(handle_, variant_pack, workspace.data<void>()));
+#endif
   }
 
  private:
@@ -158,18 +173,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& encoder = cu::get_command_encoder(s);
 
   assert(inputs.size() == 2);
-  const auto& in = inputs[0];
-  const auto& wt = inputs[1];
+  const array& input = inputs[0];
+  const array& filter = inputs[1];
   out.set_data(allocator::malloc(out.nbytes()));
 
-  int ndim = in.ndim();
-  auto [input_shape, input_strides] = cu::swapaxes(in, 1, ndim - 1);
-  auto [filter_shape, filter_strides] = cu::swapaxes(wt, 1, ndim - 1);
+  // cuDNN requires dims to be passed as NCHW.
+  int ndim = input.ndim();
+  auto [input_shape, input_strides] = cu::swapaxes(input, 1, ndim - 1);
+  auto [filter_shape, filter_strides] = cu::swapaxes(filter, 1, ndim - 1);
   auto [output_shape, output_strides] = cu::swapaxes(out, 1, ndim - 1);
 
   cu::Convolution conv(
       cu::device(s.device),
-      in.dtype(),
+      input.dtype(),
       input_shape,
       input_strides,
       filter_shape,
@@ -181,7 +197,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
       std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
       std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
       groups_);
-  conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
+  conv.run(encoder, input.data<void>(), filter.data<void>(), out.data<void>());
 }
 
 } // namespace mlx::core
@@ -86,23 +86,26 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
 
 CommandEncoder::CaptureContext::~CaptureContext() {
   CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
+  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
+      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+
+  // Extract and add as single kernel node when possible.
   size_t num_nodes;
   CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
   if (num_nodes == 1) {
     cudaGraphNode_t captured_node;
     CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
-    CUDA_KERNEL_NODE_PARAMS params;
-    CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
-    cudaGraphNode_t node;
-    CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, &params));
-    enc.insert_graph_dependencies(GraphNode{node, 'K'});
-  } else {
-    cudaGraphNode_t node;
-    CHECK_CUDA_ERROR(
-        cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
-    enc.insert_graph_dependencies(GraphNode{node, 'G'});
+    cudaGraphNodeType type;
+    CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
+    if (type == cudaGraphNodeTypeKernel) {
+      CUDA_KERNEL_NODE_PARAMS params;
+      CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
+      enc.add_kernel_node(params);
+      return;
+    }
   }
-  CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
+  // Otherwise add the captured graph as subgraph.
+  enc.add_graph_node(graph);
 }
 
 CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
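The destructor above inspects the captured graph and, only if its single node really is a kernel node, re-adds it directly; anything else (such as a cuDNN conv capture) falls through to the child-graph path. The `unique_ptr` deleter plays the role of the old trailing `cudaGraphDestroy`, covering the early-return path as well. A minimal standalone sketch of this capture-inspect-embed pattern using the public CUDA runtime graph API (not MLX code; error checking, instantiation, and launch omitted for brevity):

    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void noop() {}

    int main() {
      cudaStream_t stream;
      cudaStreamCreate(&stream);

      // Record work into a graph via stream capture.
      cudaGraph_t captured;
      cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
      noop<<<1, 1, 0, stream>>>();
      cudaStreamEndCapture(stream, &captured);

      // Inspect the capture: how many nodes, and of what type?
      size_t num_nodes = 0;
      cudaGraphGetNodes(captured, nullptr, &num_nodes);
      if (num_nodes == 1) {
        cudaGraphNode_t node;
        cudaGraphGetNodes(captured, &node, &num_nodes);
        cudaGraphNodeType type;
        cudaGraphNodeGetType(node, &type);
        printf("single node, is kernel: %d\n", type == cudaGraphNodeTypeKernel);
      }

      // Fallback: splice the whole capture into an outer graph as a child node.
      cudaGraph_t outer;
      cudaGraphCreate(&outer, 0);
      cudaGraphNode_t child;
      cudaGraphAddChildGraphNode(&child, outer, nullptr, 0, captured);
      cudaGraphDestroy(captured);  // outer keeps its own clone of the capture

      cudaGraphDestroy(outer);
      cudaStreamDestroy(stream);
      return 0;
    }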
@@ -236,10 +239,7 @@ void CommandEncoder::add_kernel_node(
   kernel_params.gridDim = grid_dim;
   kernel_params.blockDim = block_dim;
   kernel_params.kernelParams = params;
-  cudaGraphNode_t node;
-  CHECK_CUDA_ERROR(
-      cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  add_kernel_node(kernel_params);
 }
 
 void CommandEncoder::add_kernel_node(
@@ -256,12 +256,27 @@ void CommandEncoder::add_kernel_node(
   kernel_params.blockDimY = block_dim.y;
   kernel_params.blockDimZ = block_dim.z;
   kernel_params.kernelParams = params;
-  CUgraphNode node;
-  CHECK_CUDA_ERROR(
-      cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
+  add_kernel_node(kernel_params);
+}
+
+void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
   insert_graph_dependencies(GraphNode{node, 'K'});
 }
 
+void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
+  CUgraphNode node;
+  CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
+  insert_graph_dependencies(GraphNode{node, 'K'});
+}
+
+void CommandEncoder::add_graph_node(cudaGraph_t child) {
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
+  insert_graph_dependencies(GraphNode{node, 'G'});
+}
+
 void CommandEncoder::commit() {
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
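The new overloads funnel every node insertion through one code path, taking prebuilt node parameters from either the runtime API (`cudaKernelNodeParams`) or the driver API (`CUDA_KERNEL_NODE_PARAMS`). A standalone sketch of building the runtime-API form by hand, the shape the first overload consumes (not MLX code; the kernel and names are illustrative):

    #include <cuda_runtime.h>

    __global__ void add_one(int* x) {
      *x += 1;
    }

    int main() {
      int* value;
      cudaMalloc(&value, sizeof(int));
      cudaMemset(value, 0, sizeof(int));

      cudaGraph_t graph;
      cudaGraphCreate(&graph, 0);

      // Describe the kernel launch as graph-node parameters.
      void* args[] = {&value};
      cudaKernelNodeParams params = {};
      params.func = reinterpret_cast<void*>(add_one);
      params.gridDim = dim3(1);
      params.blockDim = dim3(1);
      params.sharedMemBytes = 0;
      params.kernelParams = args;
      params.extra = nullptr;

      cudaGraphNode_t node;
      cudaGraphAddKernelNode(&node, graph, nullptr, 0, &params);

      cudaGraphDestroy(graph);
      cudaFree(value);
      return 0;
    }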
@@ -66,6 +66,11 @@ class CommandEncoder {
   void
   add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
 
+  // Low-level graph helpers.
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
+  void add_graph_node(cudaGraph_t child);
+
   void add_temporary(const array& arr) {
     temporaries_.push_back(arr.data_shared_ptr());
   }
@@ -4,6 +4,7 @@
 #define _USE_MATH_DEFINES
 
 #include <cmath>
+#include <iostream>
 #include <numeric>
 
 #include "doctest/doctest.h"
@@ -3642,6 +3643,7 @@ TEST_CASE("test conv1d") {
       {1, 3, 4});
 
   auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
+  std::cout << out << std::endl;
   CHECK(allclose(out, expected).item<bool>());
 }
 
@@ -3720,6 +3722,7 @@ TEST_CASE("test conv2d") {
   auto expected =
       array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
   auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
+  std::cout << out << std::endl;
   CHECK(allclose(out, expected).item<bool>());
 }
 