diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index fc7b0fee5..4a2917254 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -9,6 +9,7 @@ #undef CHECK_CUDA_ERROR #include +#include #include #include @@ -17,136 +18,6 @@ namespace mlx::core { -namespace cu { - -using namespace cudnn_frontend; - -// The populate_cuda_graph API is supposed to integrate cuDNN with CUDA graph -// but it does not seems to support convolution yet. -#define CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API 0 - -#define CHECK_CUDNN_FE_ERROR(cmd) \ - if (cmd.is_bad()) { \ - throw std::runtime_error(fmt::format("{} failed.", #cmd)); \ - } - -class Convolution { - public: - Convolution( - Device& device, - Dtype dtype, - const std::vector& input_shape, - const std::vector& input_strides, - const std::vector& filter_shape, - const std::vector& filter_strides, - const std::vector& output_shape, - const std::vector& output_strides, - const std::vector& stride, - const std::vector& padding_lo, - const std::vector& padding_hi, - const std::vector& dilation) - : handle_(device.cudnn_handle()) { - auto cudnn_type = dtype_to_cudnn_type(dtype); - bool is_half = dtype == float16 || dtype == bfloat16; - - graph_.set_io_data_type(cudnn_type) - .set_intermediate_data_type(cudnn_type) - .set_compute_data_type(is_half ? DataType_t::FLOAT : cudnn_type); -#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API - graph_.select_behavior_notes( - {BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); -#endif - input_attr_ = graph_.tensor(graph::Tensor_attributes() - .set_data_type(cudnn_type) - .set_dim(input_shape) - .set_stride(input_strides)); - filter_attr_ = graph_.tensor(graph::Tensor_attributes() - .set_data_type(cudnn_type) - .set_dim(filter_shape) - .set_stride(filter_strides)); - - auto conv_options = graph::Conv_fprop_attributes() - .set_pre_padding(padding_lo) - .set_post_padding(padding_hi) - .set_stride(stride) - .set_dilation(dilation); - output_attr_ = graph_.conv_fprop(input_attr_, filter_attr_, conv_options); - output_attr_->set_output(true) - .set_data_type(cudnn_type) - .set_dim(output_shape) - .set_stride(output_strides); - - CHECK_CUDNN_FE_ERROR(graph_.validate()); - CHECK_CUDNN_FE_ERROR(graph_.build_operation_graph(handle_)); - CHECK_CUDNN_FE_ERROR(graph_.create_execution_plans({HeurMode_t::A})); - CHECK_CUDNN_FE_ERROR(graph_.check_support(handle_)); - CHECK_CUDNN_FE_ERROR(graph_.build_plans(handle_)); - CHECK_CUDNN_FE_ERROR(graph_.get_workspace_size(workspace_size_)); - } - - void run( - CommandEncoder& encoder, - const void* input, - const void* filter, - void* output) { - array workspace( - allocator::malloc(workspace_size_), - {static_cast(workspace_size_)}, - int8); - encoder.add_temporary(workspace); - - std::unordered_map variant_pack{ - {input_attr_->get_uid(), const_cast(input)}, - {filter_attr_->get_uid(), const_cast(filter)}, - {output_attr_->get_uid(), output}}; - -#if CUDNN_CONV_SUPPORTS_CUDA_GRAPH_NATIVE_API - cudaGraph_t cudnn_cuda_graph; - cudaGraphCreate(&cudnn_cuda_graph, 0); - CHECK_CUDNN_FE_ERROR(graph_.populate_cuda_graph( - handle_, variant_pack, workspace.data(), cudnn_cuda_graph)); - encoder.add_graph_node(cudnn_cuda_graph); - cudaGraphDestroy(cudnn_cuda_graph); -#else - auto capture = encoder.capture_context(); - CHECK_CUDNN_FE_ERROR( - graph_.execute(handle_, variant_pack, workspace.data())); -#endif - } - - private: - DataType_t dtype_to_cudnn_type(Dtype dtype) { - switch (dtype) { - case int8: - return DataType_t::INT8; - case int32: - return DataType_t::INT32; - case uint8: - return DataType_t::UINT8; - case float16: - return DataType_t::HALF; - case bfloat16: - return DataType_t::BFLOAT16; - case float32: - return DataType_t::FLOAT; - case float64: - return DataType_t::DOUBLE; - default: - throw std::runtime_error(fmt::format( - "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype))); - } - } - - cudnnHandle_t handle_; - graph::Graph graph_; - std::shared_ptr input_attr_; - std::shared_ptr filter_attr_; - std::shared_ptr output_attr_; - int64_t workspace_size_{0}; -}; - -} // namespace cu - namespace { template @@ -154,20 +25,158 @@ inline std::vector convert_vector(const std::vector& vec) { return std::vector(vec.begin(), vec.end()); } -auto nhwc_to_nchw(const array& in) { - auto shape = convert_vector(in.shape()); +auto nhwc_to_nchw(const array& x) { + auto shape = convert_vector(x.shape()); shape.insert(shape.begin() + 1, shape.back()); shape.erase(shape.end() - 1); - auto strides = convert_vector(in.strides()); + auto strides = convert_vector(x.strides()); strides.insert(strides.begin() + 1, strides.back()); strides.erase(strides.end() - 1); return std::make_tuple(shape, strides); } +inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) { + switch (dtype) { + case int8: + return CUDNN_DATA_INT8; + case int32: + return CUDNN_DATA_INT32; + case uint8: + return CUDNN_DATA_UINT8; + case float16: + return CUDNN_DATA_HALF; + case bfloat16: + return CUDNN_DATA_BFLOAT16; + case float32: + return CUDNN_DATA_FLOAT; + case float64: + return CUDNN_DATA_DOUBLE; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype))); + } +} + +inline int64_t get_alignment(const array& x) { + int64_t alignment = 1; + uintptr_t address = reinterpret_cast(x.data()); + for (; alignment < 32; alignment *= 2) { + if (address % (alignment * 2)) { + return alignment; + } + } + return alignment; +} + +inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) { + auto [shape, strides] = nhwc_to_nchw(x); + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(get_alignment(x)) + .setDataType(dtype_to_cudnn_type(x.dtype())) + .build(); +} + +cudnn_frontend::EngineConfigList get_engine_configs( + const cudnnBackendDescriptorType_t& backend_type, + cudnn_frontend::OperationGraph& op_graph, + bool use_fallback = false) { + cudnn_frontend::GeneratorSource source; + if (use_fallback) { + source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) { + auto fallback = cudnn_frontend::EngineFallbackListBuilder() + .setOperationGraph(op_graph) + .setOperation(backend_type) + .build(); + return fallback.getFallbackList(); + }; + } else { + source = [](cudnn_frontend::OperationGraph& op_graph) { + auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() + .setOperationGraph(op_graph) + .setHeurMode(CUDNN_HEUR_MODE_A) + .build(); + return heuristics.getEngineConfig(heuristics.getEngineConfigCount()); + }; + } + + cudnn_frontend::EngineConfigGenerator generator(1, &source); + return generator.generate_engine_config(op_graph); +} + +bool execute_plan( + cu::CommandEncoder& encoder, + cudnn_frontend::ManagedOpaqueDescriptor& config, + const std::string& op_graph_tag, + const array& in, + const array& wt, + array& out) { + auto handle = encoder.device().cudnn_handle(); + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(config, op_graph_tag) + .build(); + + int64_t workspace_size = plan.getWorkspaceSize(); + array workspace( + allocator::malloc(workspace_size), + {static_cast(workspace_size)}, + int8); + + int64_t uids[3] = {'x', 'w', 'y'}; + void* data_ptrs[3] = { + const_cast(in.data()), + const_cast(wt.data()), + out.data(), + }; + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace.data()) + .setDataPointers(3, data_ptrs) + .setUids(3, uids) + .build(); + + auto capture = encoder.capture_context(); + if (cudnnBackendExecute( + handle, plan.get_raw_desc(), variantPack.get_raw_desc()) != + CUDNN_STATUS_SUCCESS) { + // Discard the captured graph when failed. + capture.discard = true; + return false; + } + + encoder.add_temporary(workspace); + return true; +} + +bool execute_plans( + cu::CommandEncoder& encoder, + cudnn_frontend::EngineConfigList& configs, + const std::string& op_graph_tag, + const array& in, + const array& wt, + array& out) { + for (auto& config : configs) { + try { + if (execute_plan(encoder, config, op_graph_tag, in, wt, out)) { + return true; + } + } catch (cudnn_frontend::cudnnException&) { + } + } + return false; +} + } // namespace void Convolution::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Convolution::eval_gpu"); + if (out.size() == 0) { + return; + } + assert(inputs.size() == 2); array in = inputs[0]; array wt = inputs[1]; @@ -176,8 +185,8 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - // While cuDNN supports passing arbitrary strides, it would fail to build a - // plan with non-contiguous input. + // cuDNN requires contiguous input. + // TODO: Handle NCHW format specially. if (!in.flags().row_contiguous) { in = contiguous_copy_gpu(in, s); encoder.add_temporary(in); @@ -191,25 +200,53 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(wt); encoder.set_output_array(out); - // cuDNN requires dims to be passed as NCHW. - auto [input_shape, input_strides] = nhwc_to_nchw(in); - auto [filter_shape, filter_strides] = nhwc_to_nchw(wt); - auto [output_shape, output_strides] = nhwc_to_nchw(out); + // TODO: Searching a working execution plan is expensive, add cache. - cu::Convolution conv( - cu::device(s.device), - in.dtype(), - input_shape, - input_strides, - filter_shape, - filter_strides, - output_shape, - output_strides, - convert_vector(kernel_strides_), - convert_vector(padding_lo_), - convert_vector(padding_hi_), - convert_vector(kernel_dilation_)); - conv.run(encoder, in.data(), wt.data(), out.data()); + // Build operation graph. + auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16) + ? CUDNN_DATA_FLOAT + : dtype_to_cudnn_type(in.dtype()); + + auto stride = convert_vector(kernel_strides_); + auto padding_lo = convert_vector(padding_lo_); + auto padding_hi = convert_vector(padding_hi_); + auto dilation = convert_vector(kernel_dilation_); + + auto conv_desc = cudnn_frontend::ConvDescBuilder() + .setDataType(compute_data_type) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(stride.size()) + .setStrides(stride.size(), stride.data()) + .setPrePadding(padding_lo.size(), padding_lo.data()) + .setPostPadding(padding_hi.size(), padding_hi.data()) + .setDilation(dilation.size(), dilation.data()) + .build(); + + auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR; + auto op = cudnn_frontend::OperationBuilder(backend_type) + .setxDesc(build_tensor('x', in)) + .setwDesc(build_tensor('w', wt)) + .setyDesc(build_tensor('y', out)) + .setcDesc(conv_desc) + .build(); + + std::array ops = {&op}; + auto op_graph = cudnn_frontend::OperationGraphBuilder() + .setHandle(encoder.device().cudnn_handle()) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + // Try to run plans based on heuristics. + auto configs = get_engine_configs(backend_type, op_graph); + if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) { + return; + } + // Then try fallback plans. + configs = get_engine_configs(backend_type, op_graph, /* use_fallback */ true); + if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) { + return; + } + throw std::runtime_error("Unable to find an engine for convolution."); } } // namespace mlx::core diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index a9066cdd4..2e3a0b7c9 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -80,6 +80,7 @@ CommandEncoder& Device::get_command_encoder(Stream s) { } CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { + enc.device().make_current(); CHECK_CUDA_ERROR( cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); } @@ -88,6 +89,9 @@ CommandEncoder::CaptureContext::~CaptureContext() { CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph)); std::unique_ptr graph_freer( &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); }); + if (discard) { + return; + } // Extract and add as single kernel node when possible. size_t num_nodes; diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index e3456f7ed..25c26fb0d 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -22,6 +22,7 @@ class CommandEncoder { ~CaptureContext(); cudaGraph_t graph; CommandEncoder& enc; + bool discard{false}; }; struct ConcurrentContext { ConcurrentContext(CommandEncoder& enc); @@ -79,6 +80,10 @@ class CommandEncoder { void maybe_commit(); void commit(); + Device& device() { + return device_; + } + CudaStream& stream() { return stream_; }