Switch to backend APIs

Cheng
2025-07-19 23:52:19 +00:00
parent bb6a75bc4a
commit c6076fc77b
3 changed files with 199 additions and 153 deletions

View File

@@ -9,6 +9,7 @@
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
@@ -17,136 +18,6 @@
namespace mlx::core {
namespace cu {
using namespace cudnn_frontend;
// The populate_cuda_graph API is supposed to integrate cuDNN with CUDA graphs
// but it does not seem to support convolution yet.
#define CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API 0
#define CHECK_CUDNN_FE_ERROR(cmd) \
if (cmd.is_bad()) { \
throw std::runtime_error(fmt::format("{} failed.", #cmd)); \
}
class Convolution {
public:
Convolution(
Device& device,
Dtype dtype,
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& input_strides,
const std::vector<int64_t>& filter_shape,
const std::vector<int64_t>& filter_strides,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& output_strides,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding_lo,
const std::vector<int64_t>& padding_hi,
const std::vector<int64_t>& dilation)
: handle_(device.cudnn_handle()) {
auto cudnn_type = dtype_to_cudnn_type(dtype);
bool is_half = dtype == float16 || dtype == bfloat16;
graph_.set_io_data_type(cudnn_type)
.set_intermediate_data_type(cudnn_type)
.set_compute_data_type(is_half ? DataType_t::FLOAT : cudnn_type);
#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
graph_.select_behavior_notes(
{BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
#endif
input_attr_ = graph_.tensor(graph::Tensor_attributes()
.set_data_type(cudnn_type)
.set_dim(input_shape)
.set_stride(input_strides));
filter_attr_ = graph_.tensor(graph::Tensor_attributes()
.set_data_type(cudnn_type)
.set_dim(filter_shape)
.set_stride(filter_strides));
auto conv_options = graph::Conv_fprop_attributes()
.set_pre_padding(padding_lo)
.set_post_padding(padding_hi)
.set_stride(stride)
.set_dilation(dilation);
output_attr_ = graph_.conv_fprop(input_attr_, filter_attr_, conv_options);
output_attr_->set_output(true)
.set_data_type(cudnn_type)
.set_dim(output_shape)
.set_stride(output_strides);
CHECK_CUDNN_FE_ERROR(graph_.validate());
CHECK_CUDNN_FE_ERROR(graph_.build_operation_graph(handle_));
CHECK_CUDNN_FE_ERROR(graph_.create_execution_plans({HeurMode_t::A}));
CHECK_CUDNN_FE_ERROR(graph_.check_support(handle_));
CHECK_CUDNN_FE_ERROR(graph_.build_plans(handle_));
CHECK_CUDNN_FE_ERROR(graph_.get_workspace_size(workspace_size_));
}
void run(
CommandEncoder& encoder,
const void* input,
const void* filter,
void* output) {
array workspace(
allocator::malloc(workspace_size_),
{static_cast<int>(workspace_size_)},
int8);
encoder.add_temporary(workspace);
std::unordered_map<int64_t, void*> variant_pack{
{input_attr_->get_uid(), const_cast<void*>(input)},
{filter_attr_->get_uid(), const_cast<void*>(filter)},
{output_attr_->get_uid(), output}};
#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
cudaGraph_t cudnn_cuda_graph;
cudaGraphCreate(&cudnn_cuda_graph, 0);
CHECK_CUDNN_FE_ERROR(graph_.populate_cuda_graph(
handle_, variant_pack, workspace.data<void>(), cudnn_cuda_graph));
encoder.add_graph_node(cudnn_cuda_graph);
cudaGraphDestroy(cudnn_cuda_graph);
#else
auto capture = encoder.capture_context();
CHECK_CUDNN_FE_ERROR(
graph_.execute(handle_, variant_pack, workspace.data<void>()));
#endif
}
private:
DataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return DataType_t::INT8;
case int32:
return DataType_t::INT32;
case uint8:
return DataType_t::UINT8;
case float16:
return DataType_t::HALF;
case bfloat16:
return DataType_t::BFLOAT16;
case float32:
return DataType_t::FLOAT;
case float64:
return DataType_t::DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
cudnnHandle_t handle_;
graph::Graph graph_;
std::shared_ptr<graph::Tensor_attributes> input_attr_;
std::shared_ptr<graph::Tensor_attributes> filter_attr_;
std::shared_ptr<graph::Tensor_attributes> output_attr_;
int64_t workspace_size_{0};
};
} // namespace cu
namespace {
template <typename T, typename U>
@@ -154,20 +25,158 @@ inline std::vector<T> convert_vector(const std::vector<U>& vec) {
  return std::vector<T>(vec.begin(), vec.end());
}
auto nhwc_to_nchw(const array& x) {
  auto shape = convert_vector<int64_t>(x.shape());
  shape.insert(shape.begin() + 1, shape.back());
  shape.erase(shape.end() - 1);
  auto strides = convert_vector<int64_t>(x.strides());
  strides.insert(strides.begin() + 1, strides.back());
  strides.erase(strides.end() - 1);
  return std::make_tuple(shape, strides);
}
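// For illustration (not part of this commit): a row-contiguous NHWC array
// with shape {8, 32, 32, 3} has strides {3072, 96, 3, 1}; nhwc_to_nchw
// reports it to cuDNN as dims {8, 3, 32, 32} with strides {3072, 1, 96, 3}.
// Only the descriptor changes; no data is permuted.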
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
inline int64_t get_alignment(const array& x) {
int64_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
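// For illustration (not part of this commit): get_alignment returns the
// largest power of two that divides the data address, capped at 32 bytes;
// e.g. an address ending in 0x90 yields 16. cuDNN can select engines with
// wider vectorized loads for more strongly aligned pointers.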
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
cudnn_frontend::EngineConfigList get_engine_configs(
const cudnnBackendDescriptorType_t& backend_type,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = false) {
cudnn_frontend::GeneratorSource source;
if (use_fallback) {
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
};
} else {
source = [](cudnn_frontend::OperationGraph& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
};
}
cudnn_frontend::EngineConfigGenerator generator(1, &source);
return generator.generate_engine_config(op_graph);
}
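// Net effect: the caller gets either cuDNN's heuristic ranking
// (CUDNN_HEUR_MODE_A) or the curated fallback list for this op type;
// eval_gpu below tries the heuristic configs first, then the fallbacks.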
bool execute_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ManagedOpaqueDescriptor& config,
const std::string& op_graph_tag,
const array& in,
const array& wt,
array& out) {
auto handle = encoder.device().cudnn_handle();
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, op_graph_tag)
.build();
int64_t workspace_size = plan.getWorkspaceSize();
array workspace(
allocator::malloc(workspace_size),
{static_cast<int>(workspace_size)},
int8);
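  // These uids must match the ids the tensors were created with in
  // build_tensor ('x', 'w', 'y'), so cuDNN can pair each data pointer
  // with its tensor descriptor.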
int64_t uids[3] = {'x', 'w', 'y'};
void* data_ptrs[3] = {
const_cast<void*>(in.data<void>()),
const_cast<void*>(wt.data<void>()),
out.data<void>(),
};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
auto capture = encoder.capture_context();
if (cudnnBackendExecute(
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
CUDNN_STATUS_SUCCESS) {
    // Discard the captured graph when execution fails.
capture.discard = true;
return false;
}
encoder.add_temporary(workspace);
return true;
}
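// Try candidate engine configs in order until one builds and runs; configs
// unsupported on this machine throw cudnn_frontend::cudnnException and are
// skipped.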
bool execute_plans(
cu::CommandEncoder& encoder,
cudnn_frontend::EngineConfigList& configs,
const std::string& op_graph_tag,
const array& in,
const array& wt,
array& out) {
for (auto& config : configs) {
try {
if (execute_plan(encoder, config, op_graph_tag, in, wt, out)) {
return true;
}
} catch (cudnn_frontend::cudnnException&) {
}
}
return false;
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Convolution::eval_gpu");
if (out.size() == 0) {
return;
}
  assert(inputs.size() == 2);
  array in = inputs[0];
  array wt = inputs[1];
@@ -176,8 +185,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  // cuDNN requires contiguous input.
  // TODO: Handle NCHW format specially.
  if (!in.flags().row_contiguous) {
    in = contiguous_copy_gpu(in, s);
    encoder.add_temporary(in);
@@ -191,25 +200,53 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
  encoder.set_input_array(wt);
  encoder.set_output_array(out);
  // cuDNN requires dims to be passed as NCHW.
  auto [input_shape, input_strides] = nhwc_to_nchw(in);
  auto [filter_shape, filter_strides] = nhwc_to_nchw(wt);
  auto [output_shape, output_strides] = nhwc_to_nchw(out);
  cu::Convolution conv(
      cu::device(s.device),
      in.dtype(),
      input_shape,
      input_strides,
      filter_shape,
      filter_strides,
      output_shape,
      output_strides,
      convert_vector<int64_t>(kernel_strides_),
      convert_vector<int64_t>(padding_lo_),
      convert_vector<int64_t>(padding_hi_),
      convert_vector<int64_t>(kernel_dilation_));
  conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
  // TODO: Searching a working execution plan is expensive, add cache.
  // Build operation graph.
  auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
      ? CUDNN_DATA_FLOAT
      : dtype_to_cudnn_type(in.dtype());
  auto stride = convert_vector<int64_t>(kernel_strides_);
  auto padding_lo = convert_vector<int64_t>(padding_lo_);
  auto padding_hi = convert_vector<int64_t>(padding_hi_);
  auto dilation = convert_vector<int64_t>(kernel_dilation_);
  auto conv_desc = cudnn_frontend::ConvDescBuilder()
                       .setDataType(compute_data_type)
                       .setMathMode(CUDNN_CROSS_CORRELATION)
                       .setNDims(stride.size())
                       .setStrides(stride.size(), stride.data())
                       .setPrePadding(padding_lo.size(), padding_lo.data())
                       .setPostPadding(padding_hi.size(), padding_hi.data())
                       .setDilation(dilation.size(), dilation.data())
                       .build();
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_tensor('x', in))
.setwDesc(build_tensor('w', wt))
.setyDesc(build_tensor('y', out))
.setcDesc(conv_desc)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
auto op_graph = cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
// Try to run plans based on heuristics.
auto configs = get_engine_configs(backend_type, op_graph);
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
return;
}
// Then try fallback plans.
configs = get_engine_configs(backend_type, op_graph, /* use_fallback */ true);
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
return;
}
throw std::runtime_error("Unable to find an engine for convolution.");
}
} // namespace mlx::core
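
The TODO above flags that searching for a working execution plan on every call is expensive. A minimal sketch of the cache it hints at, keyed on everything that determines the operation graph and storing the engine config that last succeeded; the key layout, the placement after op_graph is built, and the conv_plan_cache map are assumptions, not part of this commit:

// Hypothetical cache: convolution description -> last working engine config.
// ManagedOpaqueDescriptor is a shared_ptr alias, so entries are cheap to copy.
static std::unordered_map<std::string, cudnn_frontend::ManagedOpaqueDescriptor>
    conv_plan_cache;

// In eval_gpu, after building op_graph (fmt::join needs <fmt/ranges.h>):
std::string key = fmt::format(
    "{};{};{};{};{};{};{}",
    dtype_to_string(in.dtype()),
    fmt::join(in.shape(), ","),
    fmt::join(wt.shape(), ","),
    fmt::join(kernel_strides_, ","),
    fmt::join(padding_lo_, ","),
    fmt::join(padding_hi_, ","),
    fmt::join(kernel_dilation_, ","));
if (auto it = conv_plan_cache.find(key); it != conv_plan_cache.end()) {
  if (execute_plan(encoder, it->second, op_graph.getTag(), in, wt, out)) {
    return; // reuse the cached config and skip the search
  }
  conv_plan_cache.erase(it); // stale entry; fall through to the full search
}
// ... run get_engine_configs/execute_plans as above, and on success record
// the winning config: conv_plan_cache[key] = config;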

View File

@@ -80,6 +80,7 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
enc.device().make_current();
  CHECK_CUDA_ERROR(
      cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}
@@ -88,6 +89,9 @@ CommandEncoder::CaptureContext::~CaptureContext() {
  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
if (discard) {
return;
}
  // Extract and add as single kernel node when possible.
  size_t num_nodes;
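
The discard flag added above is what lets a caller abandon a stream capture that failed partway. A minimal usage sketch, mirroring execute_plan in the convolution file (names as in this commit):

{
  // Work launched on the encoder's stream is now being captured.
  auto capture = encoder.capture_context();
  if (cudnnBackendExecute(handle, plan.get_raw_desc(),
                          variantPack.get_raw_desc()) != CUDNN_STATUS_SUCCESS) {
    // Make ~CaptureContext destroy the captured graph instead of
    // adding it to the command encoder.
    capture.discard = true;
  }
} // capture ends here: the graph is either added or thrown away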

View File

@@ -22,6 +22,7 @@ class CommandEncoder {
    ~CaptureContext();
    cudaGraph_t graph;
    CommandEncoder& enc;
bool discard{false};
  };
  struct ConcurrentContext {
    ConcurrentContext(CommandEncoder& enc);
@@ -79,6 +80,10 @@ class CommandEncoder {
  void maybe_commit();
  void commit();
Device& device() {
return device_;
}
  CudaStream& stream() {
    return stream_;
  }
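
The new device() accessor is how the convolution code reaches the cuDNN handle from an encoder, e.g.:

auto handle = encoder.device().cudnn_handle();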