mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Switch to backend apis
@@ -9,6 +9,7 @@
 #undef CHECK_CUDA_ERROR
 
 #include <cudnn_frontend.h>
+#include <cudnn_frontend_find_plan.h>
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
 
@@ -17,136 +18,6 @@
 
 namespace mlx::core {
 
-namespace cu {
-
-using namespace cudnn_frontend;
-
-// The populate_cuda_graph API is supposed to integrate cuDNN with CUDA
-// graphs, but it does not seem to support convolution yet.
-#define CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API 0
-
-#define CHECK_CUDNN_FE_ERROR(cmd)                               \
-  if (cmd.is_bad()) {                                           \
-    throw std::runtime_error(fmt::format("{} failed.", #cmd));  \
-  }
-
-class Convolution {
- public:
-  Convolution(
-      Device& device,
-      Dtype dtype,
-      const std::vector<int64_t>& input_shape,
-      const std::vector<int64_t>& input_strides,
-      const std::vector<int64_t>& filter_shape,
-      const std::vector<int64_t>& filter_strides,
-      const std::vector<int64_t>& output_shape,
-      const std::vector<int64_t>& output_strides,
-      const std::vector<int64_t>& stride,
-      const std::vector<int64_t>& padding_lo,
-      const std::vector<int64_t>& padding_hi,
-      const std::vector<int64_t>& dilation)
-      : handle_(device.cudnn_handle()) {
-    auto cudnn_type = dtype_to_cudnn_type(dtype);
-    bool is_half = dtype == float16 || dtype == bfloat16;
-
-    graph_.set_io_data_type(cudnn_type)
-        .set_intermediate_data_type(cudnn_type)
-        .set_compute_data_type(is_half ? DataType_t::FLOAT : cudnn_type);
-#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
-    graph_.select_behavior_notes(
-        {BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
-#endif
-    input_attr_ = graph_.tensor(graph::Tensor_attributes()
-                                    .set_data_type(cudnn_type)
-                                    .set_dim(input_shape)
-                                    .set_stride(input_strides));
-    filter_attr_ = graph_.tensor(graph::Tensor_attributes()
-                                     .set_data_type(cudnn_type)
-                                     .set_dim(filter_shape)
-                                     .set_stride(filter_strides));
-
-    auto conv_options = graph::Conv_fprop_attributes()
-                            .set_pre_padding(padding_lo)
-                            .set_post_padding(padding_hi)
-                            .set_stride(stride)
-                            .set_dilation(dilation);
-    output_attr_ = graph_.conv_fprop(input_attr_, filter_attr_, conv_options);
-    output_attr_->set_output(true)
-        .set_data_type(cudnn_type)
-        .set_dim(output_shape)
-        .set_stride(output_strides);
-
-    CHECK_CUDNN_FE_ERROR(graph_.validate());
-    CHECK_CUDNN_FE_ERROR(graph_.build_operation_graph(handle_));
-    CHECK_CUDNN_FE_ERROR(graph_.create_execution_plans({HeurMode_t::A}));
-    CHECK_CUDNN_FE_ERROR(graph_.check_support(handle_));
-    CHECK_CUDNN_FE_ERROR(graph_.build_plans(handle_));
-    CHECK_CUDNN_FE_ERROR(graph_.get_workspace_size(workspace_size_));
-  }
-
-  void run(
-      CommandEncoder& encoder,
-      const void* input,
-      const void* filter,
-      void* output) {
-    array workspace(
-        allocator::malloc(workspace_size_),
-        {static_cast<int>(workspace_size_)},
-        int8);
-    encoder.add_temporary(workspace);
-
-    std::unordered_map<int64_t, void*> variant_pack{
-        {input_attr_->get_uid(), const_cast<void*>(input)},
-        {filter_attr_->get_uid(), const_cast<void*>(filter)},
-        {output_attr_->get_uid(), output}};
-
-#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API
-    cudaGraph_t cudnn_cuda_graph;
-    cudaGraphCreate(&cudnn_cuda_graph, 0);
-    CHECK_CUDNN_FE_ERROR(graph_.populate_cuda_graph(
-        handle_, variant_pack, workspace.data<void>(), cudnn_cuda_graph));
-    encoder.add_graph_node(cudnn_cuda_graph);
-    cudaGraphDestroy(cudnn_cuda_graph);
-#else
-    auto capture = encoder.capture_context();
-    CHECK_CUDNN_FE_ERROR(
-        graph_.execute(handle_, variant_pack, workspace.data<void>()));
-#endif
-  }
-
- private:
-  DataType_t dtype_to_cudnn_type(Dtype dtype) {
-    switch (dtype) {
-      case int8:
-        return DataType_t::INT8;
-      case int32:
-        return DataType_t::INT32;
-      case uint8:
-        return DataType_t::UINT8;
-      case float16:
-        return DataType_t::HALF;
-      case bfloat16:
-        return DataType_t::BFLOAT16;
-      case float32:
-        return DataType_t::FLOAT;
-      case float64:
-        return DataType_t::DOUBLE;
-      default:
-        throw std::runtime_error(fmt::format(
-            "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
-    }
-  }
-
-  cudnnHandle_t handle_;
-  graph::Graph graph_;
-  std::shared_ptr<graph::Tensor_attributes> input_attr_;
-  std::shared_ptr<graph::Tensor_attributes> filter_attr_;
-  std::shared_ptr<graph::Tensor_attributes> output_attr_;
-  int64_t workspace_size_{0};
-};
-
-} // namespace cu
-
 namespace {
 
 template <typename T, typename U>
@@ -154,20 +25,158 @@ inline std::vector<T> convert_vector(const std::vector<U>& vec) {
   return std::vector<T>(vec.begin(), vec.end());
 }
 
-auto nhwc_to_nchw(const array& in) {
-  auto shape = convert_vector<int64_t>(in.shape());
+auto nhwc_to_nchw(const array& x) {
+  auto shape = convert_vector<int64_t>(x.shape());
   shape.insert(shape.begin() + 1, shape.back());
   shape.erase(shape.end() - 1);
-  auto strides = convert_vector<int64_t>(in.strides());
+  auto strides = convert_vector<int64_t>(x.strides());
   strides.insert(strides.begin() + 1, strides.back());
   strides.erase(strides.end() - 1);
   return std::make_tuple(shape, strides);
 }
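// Illustrative example (assumed values, not from the commit): a
// row-contiguous NHWC array of shape (N=2, H=4, W=4, C=3) has strides
// (48, 12, 3, 1). nhwc_to_nchw only permutes the descriptors:
//   shape:   (2, 4, 4, 3)   -> (2, 3, 4, 4)
//   strides: (48, 12, 3, 1) -> (48, 1, 12, 3)
// The bytes stay in NHWC order; cuDNN sees an NCHW view through the strides.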
 
+inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
+  switch (dtype) {
+    case int8:
+      return CUDNN_DATA_INT8;
+    case int32:
+      return CUDNN_DATA_INT32;
+    case uint8:
+      return CUDNN_DATA_UINT8;
+    case float16:
+      return CUDNN_DATA_HALF;
+    case bfloat16:
+      return CUDNN_DATA_BFLOAT16;
+    case float32:
+      return CUDNN_DATA_FLOAT;
+    case float64:
+      return CUDNN_DATA_DOUBLE;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
+  }
+}
+
+inline int64_t get_alignment(const array& x) {
+  int64_t alignment = 1;
+  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
+  for (; alignment < 32; alignment *= 2) {
+    if (address % (alignment * 2)) {
+      return alignment;
+    }
+  }
+  return alignment;
+}
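// In effect: the largest power of two (capped at 32 bytes) that divides the
// data address. Worked example with an assumed address: 0x1050 (4176 = 16 *
// 261) is divisible by 2, 4, 8, and 16 but not by 32, so the loop exits with
// 16; a 32-byte-aligned address runs the loop to the end and returns 32.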
+
+inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
+  auto [shape, strides] = nhwc_to_nchw(x);
+  return cudnn_frontend::TensorBuilder()
+      .setDim(shape.size(), shape.data())
+      .setStrides(strides.size(), strides.data())
+      .setId(id)
+      .setAlignment(get_alignment(x))
+      .setDataType(dtype_to_cudnn_type(x.dtype()))
+      .build();
+}
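// Note (inferred from execute_plan below, not stated in the commit): the id
// passed here becomes the tensor's uid in the backend graph, so it must
// match the uids handed to VariantPackBuilder ('x', 'w', 'y').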
+
+cudnn_frontend::EngineConfigList get_engine_configs(
+    const cudnnBackendDescriptorType_t& backend_type,
+    cudnn_frontend::OperationGraph& op_graph,
+    bool use_fallback = false) {
+  cudnn_frontend::GeneratorSource source;
+  if (use_fallback) {
+    source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
+      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
+                          .setOperationGraph(op_graph)
+                          .setOperation(backend_type)
+                          .build();
+      return fallback.getFallbackList();
+    };
+  } else {
+    source = [](cudnn_frontend::OperationGraph& op_graph) {
+      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+                            .setOperationGraph(op_graph)
+                            .setHeurMode(CUDNN_HEUR_MODE_A)
+                            .build();
+      return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
+    };
+  }
+
+  cudnn_frontend::EngineConfigGenerator generator(1, &source);
+  return generator.generate_engine_config(op_graph);
+}
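// Reading of the above: the generator wraps exactly one source, so a call
// yields either the heuristic configs (CUDNN_HEUR_MODE_A) or the fallback
// list, never both; eval_gpu calls this twice, heuristics first.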
+
+bool execute_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ManagedOpaqueDescriptor& config,
+    const std::string& op_graph_tag,
+    const array& in,
+    const array& wt,
+    array& out) {
+  auto handle = encoder.device().cudnn_handle();
+  auto plan = cudnn_frontend::ExecutionPlanBuilder()
+                  .setHandle(handle)
+                  .setEngineConfig(config, op_graph_tag)
+                  .build();
+
+  int64_t workspace_size = plan.getWorkspaceSize();
+  array workspace(
+      allocator::malloc(workspace_size),
+      {static_cast<int>(workspace_size)},
+      int8);
+
+  int64_t uids[3] = {'x', 'w', 'y'};
+  void* data_ptrs[3] = {
+      const_cast<void*>(in.data<void>()),
+      const_cast<void*>(wt.data<void>()),
+      out.data<void>(),
+  };
+
+  auto variantPack = cudnn_frontend::VariantPackBuilder()
+                         .setWorkspacePointer(workspace.data<void>())
+                         .setDataPointers(3, data_ptrs)
+                         .setUids(3, uids)
+                         .build();
+
+  auto capture = encoder.capture_context();
+  if (cudnnBackendExecute(
+          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
+      CUDNN_STATUS_SUCCESS) {
+    // Discard the captured graph on failure.
+    capture.discard = true;
+    return false;
+  }
+
+  encoder.add_temporary(workspace);
+  return true;
+}
+
+bool execute_plans(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::EngineConfigList& configs,
+    const std::string& op_graph_tag,
+    const array& in,
+    const array& wt,
+    array& out) {
+  for (auto& config : configs) {
+    try {
+      if (execute_plan(encoder, config, op_graph_tag, in, wt, out)) {
+        return true;
+      }
+    } catch (cudnn_frontend::cudnnException&) {
+    }
+  }
+  return false;
+}
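// The empty catch is deliberate as far as the code shows: a config from the
// heuristics may fail to build (cudnnException) or fail to execute (false
// from execute_plan); both simply mean "try the next config".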
 
 } // namespace
 
 void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Convolution::eval_gpu");
   if (out.size() == 0) {
     return;
   }
 
   assert(inputs.size() == 2);
   array in = inputs[0];
   array wt = inputs[1];
@@ -176,8 +185,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
 
-  // While cuDNN supports passing arbitrary strides, it would fail to build a
-  // plan with non-contiguous input.
+  // cuDNN requires contiguous input.
+  // TODO: Handle NCHW format specially.
   if (!in.flags().row_contiguous) {
     in = contiguous_copy_gpu(in, s);
     encoder.add_temporary(in);
@@ -191,25 +200,53 @@
   encoder.set_input_array(wt);
   encoder.set_output_array(out);
 
-  // cuDNN requires dims to be passed as NCHW.
-  auto [input_shape, input_strides] = nhwc_to_nchw(in);
-  auto [filter_shape, filter_strides] = nhwc_to_nchw(wt);
-  auto [output_shape, output_strides] = nhwc_to_nchw(out);
-
-  cu::Convolution conv(
-      cu::device(s.device),
-      in.dtype(),
-      input_shape,
-      input_strides,
-      filter_shape,
-      filter_strides,
-      output_shape,
-      output_strides,
-      convert_vector<int64_t>(kernel_strides_),
-      convert_vector<int64_t>(padding_lo_),
-      convert_vector<int64_t>(padding_hi_),
-      convert_vector<int64_t>(kernel_dilation_));
-  conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
+  // TODO: Searching for a working execution plan is expensive; add a cache.
+
+  // Build operation graph.
+  auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
+      ? CUDNN_DATA_FLOAT
+      : dtype_to_cudnn_type(in.dtype());
+
+  auto stride = convert_vector<int64_t>(kernel_strides_);
+  auto padding_lo = convert_vector<int64_t>(padding_lo_);
+  auto padding_hi = convert_vector<int64_t>(padding_hi_);
+  auto dilation = convert_vector<int64_t>(kernel_dilation_);
+
+  auto conv_desc = cudnn_frontend::ConvDescBuilder()
+                       .setDataType(compute_data_type)
+                       .setMathMode(CUDNN_CROSS_CORRELATION)
+                       .setNDims(stride.size())
+                       .setStrides(stride.size(), stride.data())
+                       .setPrePadding(padding_lo.size(), padding_lo.data())
+                       .setPostPadding(padding_hi.size(), padding_hi.data())
+                       .setDilation(dilation.size(), dilation.data())
+                       .build();
+
+  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
+  auto op = cudnn_frontend::OperationBuilder(backend_type)
+                .setxDesc(build_tensor('x', in))
+                .setwDesc(build_tensor('w', wt))
+                .setyDesc(build_tensor('y', out))
+                .setcDesc(conv_desc)
+                .build();
+
+  std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
+  auto op_graph = cudnn_frontend::OperationGraphBuilder()
+                      .setHandle(encoder.device().cudnn_handle())
+                      .setOperationGraph(ops.size(), ops.data())
+                      .build();
+
+  // Try to run plans based on heuristics.
+  auto configs = get_engine_configs(backend_type, op_graph);
+  if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
+    return;
+  }
+  // Then try fallback plans.
+  configs = get_engine_configs(backend_type, op_graph, /* use_fallback */ true);
+  if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
+    return;
+  }
+  throw std::runtime_error("Unable to find an engine for convolution.");
 }
 
 } // namespace mlx::core
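The TODO in eval_gpu flags the plan search as expensive. Below is a minimal sketch of the kind of cache it suggests; every name here is hypothetical and none of this code is part of the commit.

// Hypothetical sketch only; none of these names exist in the commit.
#include <tuple>
#include <vector>

struct ConvCacheKey {
  cudnnDataType_t data_type;
  std::vector<int64_t> input_shape;
  std::vector<int64_t> filter_shape;
  std::vector<int64_t> stride, padding_lo, padding_hi, dilation;
  // Ordering so the key can be used with std::map.
  bool operator<(const ConvCacheKey& o) const {
    return std::tie(data_type, input_shape, filter_shape, stride, padding_lo,
                    padding_hi, dilation) <
        std::tie(o.data_type, o.input_shape, o.filter_shape, o.stride,
                 o.padding_lo, o.padding_hi, o.dilation);
  }
};
// A per-device std::map<ConvCacheKey, /* built execution plan */> would let
// eval_gpu skip get_engine_configs/execute_plans when the same shapes repeat.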
@@ -80,6 +80,7 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
 }
 
 CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
+  enc.device().make_current();
   CHECK_CUDA_ERROR(
       cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
 }
@@ -88,6 +89,9 @@ CommandEncoder::CaptureContext::~CaptureContext() {
   CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
   std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
       &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+  if (discard) {
+    return;
+  }
 
   // Extract and add as single kernel node when possible.
   size_t num_nodes;
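// What the early return buys (inferred from the surrounding code):
// cudaStreamEndCapture always runs, leaving the stream usable, and
// graph_freer still destroys the captured graph; with `discard` set, the
// captured nodes are simply never added to the encoder, so a failed cuDNN
// execution leaves no trace in the CUDA graph being built.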
@@ -22,6 +22,7 @@ class CommandEncoder {
     ~CaptureContext();
     cudaGraph_t graph;
     CommandEncoder& enc;
+    bool discard{false};
   };
 
   struct ConcurrentContext {
     ConcurrentContext(CommandEncoder& enc);
@@ -79,6 +80,10 @@ class CommandEncoder {
   void maybe_commit();
   void commit();
 
+  Device& device() {
+    return device_;
+  }
+
   CudaStream& stream() {
     return stream_;
   }