From 6e762fe2e2c07886fe3fd47c6f22df1a8cdd8c66 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 2 Dec 2025 07:55:43 +0900 Subject: [PATCH] [CUDA] Migrate conv code to new cuDNN APIs (#2847) --- mlx/backend/cuda/conv.cpp | 190 +++++----- mlx/backend/cuda/cuda_utils.h | 3 + mlx/backend/cuda/cudnn_utils.cpp | 335 ++++++------------ mlx/backend/cuda/cudnn_utils.h | 216 +++++------ mlx/backend/cuda/device.cpp | 9 - .../cuda/scaled_dot_product_attention.cpp | 207 +++-------- mlx/backend/cuda/utils.cpp | 7 + mlx/backend/cuda/utils.h | 4 +- 8 files changed, 354 insertions(+), 617 deletions(-) diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 4f45d6c38..6d01da154 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -15,19 +15,16 @@ namespace mlx::core { namespace { -// Alias for better readability. -#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR -#define CONV_BACKWARD_INPUT \ - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR -#define CONV_BACKWARD_WEIGHT \ - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR - -// Custom placeholder representing fallback kernel. -#define CONV_FALLBACK static_cast(-1) +enum ConvBackendType { + CONV_FALLBACK, + CONV_FORWARD, + CONV_BACKWARD_INPUT, + CONV_BACKWARD_WEIGHT, +}; struct ConvCacheKey { int device_id; - cudnnDataType_t cudnn_dtype; + fe::DataType_t cudnn_dtype; std::array input_shape; std::array weight_shape; std::array stride; @@ -44,15 +41,13 @@ struct ConvCacheKey { auto& conv_cache() { static LRUBytesKeyCache< ConvCacheKey, - std::pair< - cudnnBackendDescriptorType_t, - std::optional>> + std::pair>> cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128); return cache; } -auto get_conv_op_settings( - cudnnBackendDescriptorType_t backend_type, +auto get_conv_settings( + ConvBackendType backend_type, array& x, array& w, array& y, @@ -68,8 +63,8 @@ auto get_conv_op_settings( for (int i = 0; i < padding_lo.size(); ++i) { int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1); padding_lo[i] = wt_size - padding_lo[i] - 1; - int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1); - int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1); + int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1); + int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1); padding_hi[i] = out_size - in_size + padding_hi[i]; } return std::make_tuple( @@ -95,49 +90,57 @@ auto get_conv_op_settings( } } -std::optional build_conv_op_graph( +std::optional build_conv_graph( cu::CommandEncoder& encoder, - cudnnBackendDescriptorType_t backend_type, + ConvBackendType backend_type, Dtype dtype, array& x, array& w, array& y, - const SmallVector& stride, - const SmallVector& padding_lo, - const SmallVector& padding_hi, - const SmallVector& dilation) { - try { - auto compute_dtype = (dtype == float16 || dtype == bfloat16) - ? CUDNN_DATA_FLOAT - : dtype_to_cudnn_type(dtype); - auto conv_desc = cudnn_frontend::ConvDescBuilder() - .setDataType(compute_dtype) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(stride.size()) - .setStrides(stride.size(), stride.data()) - .setPrePadding(padding_lo.size(), padding_lo.data()) - .setPostPadding(padding_hi.size(), padding_hi.data()) - .setDilation(dilation.size(), dilation.data()) - .build(); + const std::vector& stride, + const std::vector& padding_lo, + const std::vector& padding_hi, + const std::vector& dilation) { + auto compute_dtype = + (dtype == float16 || dtype == bfloat16) ? 
float32 : dtype; + DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype); + auto x_ = graph.tensor_nchw("X", 'x', x); + auto w_ = graph.tensor_nchw("W", 'w', w); - auto op = cudnn_frontend::OperationBuilder(backend_type) - .setxDesc(build_cudnn_tensor_nchw('x', x)) - .setwDesc(build_cudnn_tensor_nchw('w', w)) - .setyDesc(build_cudnn_tensor_nchw('y', y)) - .setcDesc(conv_desc) - .build(); + auto set_options = [&](auto& options) { + options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype)) + .set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION) + .set_stride(stride) + .set_pre_padding(padding_lo) + .set_post_padding(padding_hi) + .set_dilation(dilation); + }; - std::array ops = {&op}; - return cudnn_frontend::OperationGraphBuilder() - .setHandle(encoder.device().cudnn_handle()) - .setOperationGraph(ops.size(), ops.data()) - .build(); - } catch (cudnn_frontend::cudnnException& error) { - if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) { - throw; - } + std::shared_ptr y_; + if (backend_type == CONV_FORWARD) { + auto options = fe::graph::Conv_fprop_attributes(); + set_options(options); + y_ = graph.conv_fprop(x_, w_, options); + } else if (backend_type == CONV_BACKWARD_INPUT) { + auto options = fe::graph::Conv_dgrad_attributes(); + set_options(options); + y_ = graph.conv_dgrad(x_, w_, options); + } else if (backend_type == CONV_BACKWARD_WEIGHT) { + auto options = fe::graph::Conv_wgrad_attributes(); + set_options(options); + y_ = graph.conv_wgrad(w_, x_, options); + } + graph.tensor_nchw(y_, 'y', y)->set_output(true); + + if (graph.prepare().is_bad()) { return std::nullopt; } + graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS}); + if (dtype == float32 && !env::enable_tf32()) { + graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE}); + } + CHECK_CUDNN_FE_ERROR(graph.build()); + return graph; } // Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups). @@ -181,7 +184,7 @@ array group_transpose( // eval_gpu, with cost of possible redundant copies. std::tuple prepare_args( cu::CommandEncoder& encoder, - cudnnBackendDescriptorType_t backend_type, + ConvBackendType backend_type, array in, array wt, array out, @@ -221,27 +224,11 @@ std::tuple prepare_args( return {std::move(in), std::move(wt), std::move(out)}; } -// Get the x/w/y args from the in/wt/out args depending on backend type. -inline std::tuple dispatch_args( - cudnnBackendDescriptorType_t backend_type, - array& in, - array& wt, - array& out) { - switch (backend_type) { - case CONV_BACKWARD_INPUT: - return {out, wt, in}; - case CONV_BACKWARD_WEIGHT: - return {in, out, wt}; - default: - return {in, wt, out}; - } -} - // Register inputs and outputs before actually running conv op. Can only be // called once per eval_gpu. void register_args( cu::CommandEncoder& encoder, - cudnnBackendDescriptorType_t backend_type, + ConvBackendType backend_type, array& in, array& wt, array& intermediate_out, @@ -297,16 +284,19 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { get_alignment(wt), get_alignment(out)}; if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) { - auto& [backend_type, plan] = it->second; - if (plan) { - // Run cached plan. + auto& [backend_type, graph] = it->second; + if (graph) { + // Run cached graph. 
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, groups_, s); register_args(encoder, backend_type, in, wt, out, out_); - auto [x, w, y] = dispatch_args(backend_type, in, wt, out); - if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { - throw std::runtime_error("[conv] Cached plan failed to execute."); - } + CHECK_CUDNN_FE_ERROR(graph->encode_capturing( + encoder, + { + {'x', gpu_ptr(in)}, + {'w', gpu_ptr(wt)}, + {'y', gpu_ptr(out)}, + })); } else { // Run fallback kernel. gemm_conv( @@ -327,7 +317,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { // There is no reliable way to deduce the proper cuDNN backend for the // convolution, so we make a best guess and then try. - SmallVector try_backends; + SmallVector try_backends; if (flip_) { // When weight is flipped, we assume it is backward input convolution. try_backends.push_back(CONV_BACKWARD_INPUT); @@ -345,13 +335,12 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { } // Try to build op graph. - cudnnBackendDescriptorType_t backend_type; - std::optional op_graph; + ConvBackendType backend_type; + std::optional graph; for (auto try_backend : try_backends) { - auto [in_copy, wt_copy, out_copy] = + auto [x, w, y] = prepare_args(encoder, try_backend, in, wt, out, groups_, s); - auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy); - auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings( + auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings( try_backend, x, w, @@ -361,7 +350,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { padding_hi_, kernel_dilation_, input_dilation_); - op_graph = build_conv_op_graph( + graph = build_conv_graph( encoder, try_backend, dtype, @@ -372,30 +361,27 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { padding_lo, padding_hi, dilation); - if (op_graph) { + if (graph) { backend_type = try_backend; - in = std::move(in_copy); - wt = std::move(wt_copy); - out = std::move(out_copy); + in = std::move(x); + wt = std::move(w); + out = std::move(y); break; } } - if (op_graph) { - // Find a plan for the graph and execute it. - auto plan = find_cudnn_plan_from_op_graph( - encoder.device().cudnn_handle(), backend_type, dtype, *op_graph); - if (plan) { - // Setup inputs and outputs. - register_args(encoder, backend_type, in, wt, out, out_); - - auto [x, w, y] = dispatch_args(backend_type, in, wt, out); - if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { - conv_cache().emplace( - cache_key, std::make_pair(backend_type, std::move(*plan))); - return; - } - } + if (graph) { + register_args(encoder, backend_type, in, wt, out, out_); + CHECK_CUDNN_FE_ERROR(graph->encode_capturing( + encoder, + { + {'x', gpu_ptr(in)}, + {'w', gpu_ptr(wt)}, + {'y', gpu_ptr(out)}, + })); + conv_cache().emplace( + cache_key, std::make_pair(backend_type, std::move(*graph))); + return; } // Use fallback kernel for settings not supported by cuDNN. 
diff --git a/mlx/backend/cuda/cuda_utils.h b/mlx/backend/cuda/cuda_utils.h index 812acba07..d9d557630 100644 --- a/mlx/backend/cuda/cuda_utils.h +++ b/mlx/backend/cuda/cuda_utils.h @@ -5,6 +5,7 @@ #include #include #include +#include namespace mlx::core { @@ -12,10 +13,12 @@ namespace mlx::core { void check_cublas_error(const char* name, cublasStatus_t err); void check_cuda_error(const char* name, cudaError_t err); void check_cuda_error(const char* name, CUresult err); +void check_cudnn_error(const char* name, cudnnStatus_t err); // The macro version that prints the command that failed. #define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd)) #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) +#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd)) // Base class for RAII managed CUDA resources. template diff --git a/mlx/backend/cuda/cudnn_utils.cpp b/mlx/backend/cuda/cudnn_utils.cpp index a0f7633ca..499e34a45 100644 --- a/mlx/backend/cuda/cudnn_utils.cpp +++ b/mlx/backend/cuda/cudnn_utils.cpp @@ -7,32 +7,26 @@ namespace mlx::core { namespace { -// Create a cudnn tensor descriptor. -template -inline cudnn_frontend::Tensor build_cudnn_tensor( - int64_t id, - const array& x, - const Vec& shape, - const Vec& strides) { - return cudnn_frontend::TensorBuilder() - .setDim(shape.size(), shape.data()) - .setStrides(strides.size(), strides.data()) - .setId(id) - .setAlignment(get_alignment(x)) - .setDataType(dtype_to_cudnn_type(x.dtype())) - .build(); -} +#define RETURN_IF_ERROR(cmd) \ + if (auto ret = cmd; ret.is_bad()) { \ + return ret; \ + } // In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN // whether a tensor is contiguous is determined with: // shape[dim] == shape[dim + 1] * strides[dim + 1] // So a contiguous array with singleton dims in MLX may be mistakenly treated // as strided in cuDNN, and we work around it by normalizing the strides. -Strides normalized_strides(const array& x) { - if (!x.flags().row_contiguous || x.ndim() < 2) { - return x.strides(); +std::vector normalized_strides(const array& x) { + std::vector strides(x.strides().begin(), x.strides().end()); + if (std::all_of( + strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) { + strides.back() = 1; + return strides; + } + if (!x.flags().row_contiguous || x.ndim() < 2) { + return strides; } - Strides strides = x.strides(); for (int i = x.ndim() - 2; i >= 0; --i) { if (x.shape(i) == 1) { strides[i] = x.shape(i + 1) * strides[i + 1]; @@ -42,7 +36,9 @@ Strides normalized_strides(const array& x) { } // Return the shape and strides after transposing from NHWC to NCHW. -auto nhwc_to_nchw(SmallVector shape, SmallVector strides) { +inline auto nhwc_to_nchw(const array& x) { + auto shape = convert_vector(x.shape()); + auto strides = normalized_strides(x); assert(shape.size() >= 3); shape.insert(shape.begin() + 1, shape.back()); shape.erase(shape.end() - 1); @@ -51,226 +47,95 @@ auto nhwc_to_nchw(SmallVector shape, SmallVector strides) { return std::make_tuple(std::move(shape), std::move(strides)); } -inline auto nhwc_to_nchw(const array& x) { - return nhwc_to_nchw( - convert_vector(x.shape()), normalized_strides(x)); -} - -// Return available engines for a |op_graph|. 
-cudnn_frontend::EngineConfigList get_cudnn_engine_configs( - cudnnBackendDescriptorType_t backend_type, - Dtype dtype, - cudnn_frontend::OperationGraph& op_graph, - bool use_fallback = true) { - SmallVector sources; - sources.push_back([](auto& op_graph) { - auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() - .setOperationGraph(op_graph) - .setHeurMode(CUDNN_HEUR_MODE_A) - .build(); - return heuristics.getEngineConfig(heuristics.getEngineConfigCount()); - }); - if (use_fallback) { - sources.push_back([&backend_type](auto& op_graph) { - auto fallback = cudnn_frontend::EngineFallbackListBuilder() - .setOperationGraph(op_graph) - .setOperation(backend_type) - .build(); - return fallback.getFallbackList(); - }); - } - - auto configs = - cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data()) - .generate_engine_config(op_graph); - - cudnn_frontend::EngineConfigList filtered_configs; - cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) { - if (cudnn_frontend::hasNumericalNote< - CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) { - return true; - } - if (cudnn_frontend::hasNumericalNote(c) && - dtype == float32 && !env::enable_tf32()) { - return true; - } - return false; - }); - return filtered_configs; -} - -// Take |engine_configs| and |op_graph| and find a working execution plans -// from them. -std::optional -find_cudnn_plan_from_engine_configs( - cudnnHandle_t handle, - const cudnn_frontend::EngineConfigList& engine_configs, - const cudnn_frontend::OperationGraph& op_graph) { - auto op_graph_tag = op_graph.getTag(); - for (const auto& config : engine_configs) { - try { - return cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(config, op_graph_tag) - .build(); - } catch (cudnn_frontend::cudnnException& error) { - if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) { - throw; - } - } - } - return std::nullopt; -} - -// Prepare workspace and args to execute plan. -template -bool prepare_cudnn_plan( - cu::CommandEncoder& encoder, - cudnn_frontend::ExecutionPlan& plan, - int num_args, - const int64_t* uids, - void** data_ptrs, - F&& execute) { - int workspace_size = plan.getWorkspaceSize(); - void* workspace_ptr = nullptr; - if (workspace_size > 0) { - array workspace( - cu::malloc_async(workspace_size, encoder), {workspace_size}, uint8); - encoder.add_temporary(workspace); - workspace_ptr = gpu_ptr(workspace); - } - - auto args = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(num_args, data_ptrs) - .setUids(num_args, uids) - .build(); - - auto handle = encoder.device().cudnn_handle(); - cudnnSetStream(handle, encoder.stream()); - - if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) { - return false; - } - - return true; -} - } // namespace -cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) { - auto shape = convert_vector(x.shape()); - return build_cudnn_tensor(id, x, shape, normalized_strides(x)); +fe::error_t DnnGraph::prepare() { + RETURN_IF_ERROR(validate()); + try { + RETURN_IF_ERROR(build_operation_graph(handle_)); + } catch (cudnn_frontend::cudnnException& error) { + // cuDNN bug: they did not catch all exceptions in the API. 
+ return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()}; + } + RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A})); + return {}; } -cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) { +fe::error_t DnnGraph::build() { + RETURN_IF_ERROR(check_support(handle_)); + RETURN_IF_ERROR(build_plans(handle_)); + return {}; +} + +fe::error_t DnnGraph::encode_graph( + cu::CommandEncoder& encoder, + std::unordered_map variant_pack) { + cudnnSetStream(handle_, encoder.stream()); + CudaGraph cuda_graph(encoder.device()); + RETURN_IF_ERROR(populate_cuda_graph( + handle_, variant_pack, prepare_workspace(encoder), cuda_graph)); + encoder.add_graph_node(cuda_graph); + return {}; +} + +fe::error_t DnnGraph::encode_capturing( + cu::CommandEncoder& encoder, + std::unordered_map variant_pack) { + auto* workspace_ptr = prepare_workspace(encoder); + auto capture = encoder.capture_context(); + cudnnSetStream(handle_, encoder.stream()); + auto ret = execute(handle_, variant_pack, workspace_ptr); + if (ret.is_bad()) { + capture.discard = true; + } + return ret; +} + +void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) { + int64_t workspace_size = 0; + CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size)); + if (workspace_size > 0) { + array workspace( + cu::malloc_async(workspace_size, encoder), + {static_cast(workspace_size)}, + uint8); + encoder.add_temporary(workspace); + return gpu_ptr(workspace); + } + return nullptr; +} + +void DnnGraph::set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x, + const std::vector& shape, + const std::vector& strides) { + tensor->set_uid(uid) + .set_alignment(get_alignment(x)) + .set_data_type(dtype_to_cudnn_type(x.dtype())) + .set_dim(shape) + .set_stride(strides); +} + +void DnnGraph::set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x) { + set_tensor_attrs( + tensor, + uid, + x, + convert_vector(x.shape()), + normalized_strides(x)); +} + +void DnnGraph::set_tensor_attrs_nchw( + std::shared_ptr& tensor, + int64_t uid, + const array& x) { auto [shape, strides] = nhwc_to_nchw(x); - return build_cudnn_tensor(id, x, shape, strides); + set_tensor_attrs(tensor, uid, x, shape, strides); } -cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) { - if (x.ndim() == 0) { - SmallVector scalar_dims = {1, 1, 1, 1}; - return build_cudnn_tensor(id, x, scalar_dims, scalar_dims); - } - if (x.ndim() == 1) { - int64_t s = x.shape(0); - SmallVector shape = {1, x.shape(0), 1, 1}; - SmallVector strides = {s, 1, s, s}; - return build_cudnn_tensor(id, x, shape, strides); - } - if (x.ndim() == 2) { - int64_t s = - x.flags().row_contiguous ? 
x.shape(1) * x.strides(1) : x.strides(0); - SmallVector shape = {x.shape(0), x.shape(1), 1, 1}; - SmallVector strides = {s, x.strides(1), s, s}; - return build_cudnn_tensor(id, x, shape, strides); - } - if (x.ndim() == 3 || x.ndim() == 4) { - return build_cudnn_tensor_nchw(id, x); - } - throw std::runtime_error( - fmt::format("Unsupported array with {} dims.", x.ndim())); -} - -cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) { - SmallVector scalar_dims = {1, 1, 1, 1}; - return cudnn_frontend::TensorBuilder() - .setDim(scalar_dims.size(), scalar_dims.data()) - .setStrides(scalar_dims.size(), scalar_dims.data()) - .setId(id) - .setAlignment(16) - .setDataType(dtype_to_cudnn_type(dtype)) - .setByValue(true) - .build(); -} - -std::optional find_cudnn_plan_from_op_graph( - cudnnHandle_t handle, - cudnnBackendDescriptorType_t backend_type, - Dtype dtype, - cudnn_frontend::OperationGraph& op_graph) { - auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph); - if (engine_configs.empty()) { - return std::nullopt; - } - return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph); -} - -bool encode_cudnn_plan_with_capturing( - cu::CommandEncoder& encoder, - cudnn_frontend::ExecutionPlan& plan, - int num_args, - const int64_t* uids, - void** data_ptrs) { - return prepare_cudnn_plan( - encoder, - plan, - num_args, - uids, - data_ptrs, - [&](auto handle, auto plan, auto args) { - auto capture = encoder.capture_context(); - if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) { - // Discard the captured graph when failed. - capture.discard = true; - return false; - } - return true; - }); -} - -#if CUDNN_VERSION >= 90500 -bool encode_cudnn_plan_with_graph_api( - cu::CommandEncoder& encoder, - cudnn_frontend::ExecutionPlan& plan, - CudaGraph& graph, - int num_args, - const int64_t* uids, - void** data_ptrs) { - return prepare_cudnn_plan( - encoder, - plan, - num_args, - uids, - data_ptrs, - [&](auto handle, auto plan, auto args) { - if (!graph) { - graph = CudaGraph(encoder.device()); - if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) != - CUDNN_STATUS_SUCCESS) { - return false; - } - } else { - if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) != - CUDNN_STATUS_SUCCESS) { - return false; - } - } - encoder.add_graph_node(graph); - return true; - }); -} -#endif - } // namespace mlx::core diff --git a/mlx/backend/cuda/cudnn_utils.h b/mlx/backend/cuda/cudnn_utils.h index 537e17c44..5a3d2f9b1 100644 --- a/mlx/backend/cuda/cudnn_utils.h +++ b/mlx/backend/cuda/cudnn_utils.h @@ -2,25 +2,30 @@ #pragma once -#include "mlx/array.h" -#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/utils.h" #include "mlx/dtype_utils.h" #include -#include #include -#include -#include - namespace mlx::core { namespace cu { class CommandEncoder; } +namespace fe = cudnn_frontend; + +#define CHECK_CUDNN_FE_ERROR(cmd) \ + do { \ + auto error = cmd; \ + if (!error.is_good()) { \ + throw std::runtime_error( \ + fmt::format("{} failed: {}.", #cmd, error.get_message())); \ + } \ + } while (0) + // Return pointer alignment of |x|'s data. inline uint8_t get_alignment(const array& x) { uint8_t alignment = 1; @@ -35,8 +40,31 @@ inline uint8_t get_alignment(const array& x) { // Convert the type of elements in |vec| to |T|. 
template -inline SmallVector convert_vector(const Vec& vec) { - return SmallVector(vec.begin(), vec.end()); +inline std::vector convert_vector(const Vec& vec) { + return std::vector(vec.begin(), vec.end()); +} + +// Map dtype to cudnn data type. +inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) { + switch (dtype) { + case int8: + return fe::DataType_t::INT8; + case int32: + return fe::DataType_t::INT32; + case uint8: + return fe::DataType_t::UINT8; + case float16: + return fe::DataType_t::HALF; + case bfloat16: + return fe::DataType_t::BFLOAT16; + case float32: + return fe::DataType_t::FLOAT; + case float64: + return fe::DataType_t::DOUBLE; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype))); + } } // Return an array that can be used as map key for |vec| with size <= MAX_NDIM. @@ -55,111 +83,89 @@ inline std::array vector_key(const Vec& vec) { return result; } -// Helpers used by get_data_ptrs to get pointers. -inline void* get_data_ptr(const array& arr) { - return const_cast(gpu_ptr(arr)); -} - -template >> -inline void* get_data_ptr(T& scalar) { - return &scalar; -} - -// Return an array filled with data pointers of args. -template -inline std::array get_data_ptrs(Args&... args) { - return {get_data_ptr(args)...}; -} - -// Map dtype to cudnn data type. -inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) { - switch (dtype) { - case int8: - return CUDNN_DATA_INT8; - case int32: - return CUDNN_DATA_INT32; - case uint8: - return CUDNN_DATA_UINT8; - case float16: - return CUDNN_DATA_HALF; - case bfloat16: - return CUDNN_DATA_BFLOAT16; - case float32: - return CUDNN_DATA_FLOAT; - case float64: - return CUDNN_DATA_DOUBLE; - default: - throw std::runtime_error(fmt::format( - "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype))); +// Extends cuDNN graph with helpers. +class DnnGraph : public fe::graph::Graph { + public: + DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32) + : handle_(handle) { + set_io_data_type(dtype_to_cudnn_type(io_dtype)); + set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype)); + set_compute_data_type(dtype_to_cudnn_type(compute_dtype)); } -} -// Create a tensor descriptor from |x|. -cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x); + // Create a cuDNN tensor description from MLX array |x|. + auto& tensor( + std::shared_ptr& attrs, + int64_t uid, + const array& x) { + set_tensor_attrs(attrs, uid, x); + return attrs; + } + auto tensor(const char* name, int64_t uid, const array& x) { + auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); + tensor(attrs, uid, x); + return attrs; + } -// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW. -cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x); + // Create a cuDNN tensor description from MLX array |x|, and transpose it from + // NHWC layout to NCHW. + auto& tensor_nchw( + std::shared_ptr& attrs, + int64_t uid, + const array& x) { + set_tensor_attrs_nchw(attrs, uid, x); + return attrs; + } + auto tensor_nchw(const char* name, int64_t uid, const array& x) { + auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); + tensor_nchw(attrs, uid, x); + return attrs; + } -// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it -// from NHWC to NCHW. -cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x); + // Create a cuDNN tensor for scalar. 
+ auto scalar(const char* name, int64_t uid, Dtype dtype) { + return Graph::tensor(fe::graph::Tensor_attributes() + .set_name(name) + .set_uid(uid) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(dtype_to_cudnn_type(dtype))); + } -// Create a 4D scalar tensor descriptor, which is passed by value. -cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype); + // Call this before setting notes. + fe::error_t prepare(); + // Call this after setting notes. + fe::error_t build(); -// Find a working plan for |op_graph|. -std::optional find_cudnn_plan_from_op_graph( - cudnnHandle_t handle, - cudnnBackendDescriptorType_t backend_type, - Dtype dtype, - cudnn_frontend::OperationGraph& op_graph); + // Add cuDNN graph to CUDA graph, using native CUDA graph API. + fe::error_t encode_graph( + cu::CommandEncoder& encoder, + std::unordered_map variant_pack); + // Add cuDNN graph to CUDA graph, using stream capture. + fe::error_t encode_capturing( + cu::CommandEncoder& encoder, + std::unordered_map variant_pack); -// Encode the plan to command buffer by capturing. -bool encode_cudnn_plan_with_capturing( - cu::CommandEncoder& encoder, - cudnn_frontend::ExecutionPlan& plan, - int num_args, - const int64_t* uids, - void** data_ptrs); + private: + void* prepare_workspace(cu::CommandEncoder& encoder); -#if CUDNN_VERSION >= 90500 -// Encode the plan to command buffer by using native graph api of cudnn. If the -// |graph| is empty it will be populated, otherwise it will be updated. -bool encode_cudnn_plan_with_graph_api( - cu::CommandEncoder& encoder, - cudnn_frontend::ExecutionPlan& plan, - CudaGraph& graph, - int num_args, - const int64_t* uids, - void** data_ptrs); -#endif + void set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x, + const std::vector& shape, + const std::vector& strides); + void set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x); + void set_tensor_attrs_nchw( + std::shared_ptr& tensor, + int64_t uid, + const array& x); -// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z). -template -bool encode_cudnn_plan( - cu::CommandEncoder& encoder, - cudnn_frontend::ExecutionPlan& plan, - std::initializer_list uids, - Args&... args) { - assert(uids.size() == sizeof...(args)); - auto data_ptrs = get_data_ptrs(args...); - return encode_cudnn_plan_with_capturing( - encoder, plan, uids.size(), uids.begin(), data_ptrs.data()); -} - -#if CUDNN_VERSION >= 90500 -template -bool encode_cudnn_plan( - cu::CommandEncoder& encoder, - cudnn_frontend::ExecutionPlan& plan, - CudaGraph& graph, - std::initializer_list uids, - Args&... 
args) { - assert(uids.size() == sizeof...(args)); - auto data_ptrs = get_data_ptrs(args...); - return encode_cudnn_plan_with_graph_api( - encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data()); -} -#endif + cudnnHandle_t handle_; +}; } // namespace mlx::core diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 7c4f2ab1e..beb42052a 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -14,15 +14,6 @@ namespace mlx::core::cu { namespace { -#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd)) - -void check_cudnn_error(const char* name, cudnnStatus_t err) { - if (err != CUDNN_STATUS_SUCCESS) { - throw std::runtime_error( - fmt::format("{} failed: {}.", name, cudnnGetErrorString(err))); - } -} - bool use_cuda_graphs() { static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true); return use_graphs; diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index 6a76bb131..108b0244a 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -10,46 +10,8 @@ namespace mlx::core { -namespace fe = cudnn_frontend; - namespace { -#define CHECK_CUDNN_FE_ERROR(cmd) \ - do { \ - auto error = cmd; \ - if (!error.is_good()) { \ - throw std::runtime_error( \ - fmt::format("{} failed: {}.", #cmd, error.get_message())); \ - } \ - } while (0) - -std::vector normalized_strides(const array& x) { - std::vector strides(x.strides().begin(), x.strides().end()); - if (std::all_of( - strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) { - strides.back() = 1; - return strides; - } - if (!x.flags().row_contiguous || x.ndim() < 2) { - return strides; - } - for (int i = x.ndim() - 2; i >= 0; --i) { - if (x.shape(i) == 1) { - strides[i] = x.shape(i + 1) * strides[i + 1]; - } - } - return strides; -} - -void set_tensor_attrs( - std::shared_ptr& tensor, - int64_t uid, - const array& x) { - tensor->set_uid(uid) - .set_dim({x.shape().begin(), x.shape().end()}) - .set_stride(normalized_strides(x)); -} - array prepare_sdpa_input(const array& x, Stream s) { // SDPA kernel's requirements on inputs: // 1. 
last dim's stride be 1; @@ -99,7 +61,7 @@ constexpr int QKV_NDIM = 4; struct SDPACacheKey { int device_id; - cudnnDataType_t cudnn_dtype; + fe::DataType_t cudnn_dtype; std::array q_shape; std::array k_shape; std::array v_shape; @@ -143,13 +105,13 @@ inline BytesKey build_sdpa_cache_key( } auto& sdpa_cache() { - static LRUBytesKeyCache cache( + static LRUBytesKeyCache cache( "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64); return cache; } auto& sdpa_backward_cache() { - static LRUBytesKeyCache cache( + static LRUBytesKeyCache cache( "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64); return cache; } @@ -169,7 +131,7 @@ enum UIDS { D_O, }; -fe::graph::Graph build_sdpa_graph( +DnnGraph build_sdpa_graph( cudnnHandle_t handle, const array& q, const array& k, @@ -179,34 +141,15 @@ fe::graph::Graph build_sdpa_graph( bool output_logsumexp, const array& o, const array& stats) { - auto dtype = fe::DataType_t::HALF; - if (q.dtype() == bfloat16) { - dtype = fe::DataType_t::BFLOAT16; - } + DnnGraph graph(handle, q.dtype()); - fe::graph::Graph graph; - graph.set_io_data_type(dtype) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q")); - auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K")); - auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V")); - set_tensor_attrs(q_, Q, q); - set_tensor_attrs(k_, K, k); - set_tensor_attrs(v_, V, v); - - auto scale = graph.tensor(fe::graph::Tensor_attributes() - .set_name("Scale") - .set_uid(SCALE) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); + auto q_ = graph.tensor("Q", Q, q); + auto k_ = graph.tensor("K", K, k); + auto v_ = graph.tensor("V", V, v); auto options = fe::graph::SDPA_attributes() .set_name("sdpa_cudnn") - .set_attn_scale(scale) + .set_attn_scale(graph.scalar("Scale", SCALE, float32)) .set_generate_stats(output_logsumexp); if (do_causal) { if (q.shape(2) > k.shape(2)) { @@ -216,31 +159,23 @@ fe::graph::Graph build_sdpa_graph( } } if (mask_arr) { - auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS")); - set_tensor_attrs(bias_, BIAS, *mask_arr); - options.set_bias(bias_); + options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr)); } auto [o_, stats_] = graph.sdpa(q_, k_, v_, options); - o_->set_output(true); - set_tensor_attrs(o_, O, o); + graph.tensor(o_, O, o)->set_output(true); if (output_logsumexp) { - stats_->set_output(true).set_data_type(fe::DataType_t::FLOAT); - set_tensor_attrs(stats_, STATS, stats); + graph.tensor(stats_, STATS, stats)->set_output(true); } - CHECK_CUDNN_FE_ERROR(graph.validate()); - CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle)); - CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A})); + CHECK_CUDNN_FE_ERROR(graph.prepare()); graph.select_behavior_notes( {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); - CHECK_CUDNN_FE_ERROR(graph.check_support(handle)); - CHECK_CUDNN_FE_ERROR(graph.build_plans(handle)); - + CHECK_CUDNN_FE_ERROR(graph.build()); return graph; } -fe::graph::Graph build_sdpa_backward_graph( +DnnGraph build_sdpa_backward_graph( cudnnHandle_t handle, const array& q, const array& k, @@ -253,42 +188,18 @@ fe::graph::Graph build_sdpa_backward_graph( array& d_q, array& d_k, array& d_v) { - auto dtype = fe::DataType_t::HALF; - if (q.dtype() == bfloat16) { - dtype = fe::DataType_t::BFLOAT16; - } + DnnGraph graph(handle, 
q.dtype()); - fe::graph::Graph graph; - graph.set_io_data_type(dtype) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q")); - auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K")); - auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V")); - auto o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("O")); - auto d_o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("D_O")); - auto stats_ = graph.tensor(fe::graph::Tensor_attributes().set_name("STATS")); - set_tensor_attrs(q_, Q, q); - set_tensor_attrs(k_, K, k); - set_tensor_attrs(v_, V, v); - set_tensor_attrs(o_, O, o); - set_tensor_attrs(d_o_, D_O, d_o); - set_tensor_attrs(stats_, STATS, stats); - stats_->set_data_type(fe::DataType_t::FLOAT); - - auto scale = graph.tensor(fe::graph::Tensor_attributes() - .set_name("Scale") - .set_uid(SCALE) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); + auto q_ = graph.tensor("Q", Q, q); + auto k_ = graph.tensor("K", K, k); + auto v_ = graph.tensor("V", V, v); + auto o_ = graph.tensor("O", O, o); + auto d_o_ = graph.tensor("D_O", D_O, d_o); + auto stats_ = graph.tensor("STATS", STATS, stats); auto options = fe::graph::SDPA_backward_attributes() .set_name("sdpa_backward_cudnn") - .set_attn_scale(scale) - .set_attn_scale(scale); + .set_attn_scale(graph.scalar("Scale", SCALE, float32)); if (do_causal) { if (q.shape(2) > k.shape(2)) { options.set_causal_mask(do_causal); @@ -297,56 +208,22 @@ fe::graph::Graph build_sdpa_backward_graph( } } if (mask_arr) { - auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS")); - set_tensor_attrs(bias_, BIAS, *mask_arr); - options.set_bias(bias_); + options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr)); } auto [d_q_, d_k_, d_v_] = graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options); - d_q_->set_output(true); - d_k_->set_output(true); - d_v_->set_output(true); - set_tensor_attrs(d_q_, D_Q, d_q); - set_tensor_attrs(d_k_, D_K, d_k); - set_tensor_attrs(d_v_, D_V, d_v); + graph.tensor(d_q_, D_Q, d_q)->set_output(true); + graph.tensor(d_k_, D_K, d_k)->set_output(true); + graph.tensor(d_v_, D_V, d_v)->set_output(true); - CHECK_CUDNN_FE_ERROR(graph.validate()); - CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle)); - CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A})); + CHECK_CUDNN_FE_ERROR(graph.prepare()); graph.select_behavior_notes( {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); - CHECK_CUDNN_FE_ERROR(graph.check_support(handle)); - CHECK_CUDNN_FE_ERROR(graph.build_plans(handle)); - + CHECK_CUDNN_FE_ERROR(graph.build()); return graph; } -void execute_graph( - cu::CommandEncoder& encoder, - cudnnHandle_t handle, - fe::graph::Graph& graph, - std::unordered_map& variant_pack) { - int64_t workspace_size = 0; - CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size)); - void* workspace_ptr = nullptr; - if (workspace_size > 0) { - array workspace( - cu::malloc_async(workspace_size, encoder), - {static_cast(workspace_size)}, - uint8); - encoder.add_temporary(workspace); - workspace_ptr = gpu_ptr(workspace); - } - - cudnnSetStream(handle, encoder.stream()); - - CudaGraph cuda_graph(encoder.device()); - CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph( - handle, variant_pack, workspace_ptr, cuda_graph)); - encoder.add_graph_node(cuda_graph); -} - } // namespace bool 
supports_sdpa_cudnn( @@ -420,19 +297,19 @@ void sdpa_cudnn( auto& graph = it->second; std::unordered_map variant_pack{ - {Q, const_cast(gpu_ptr(q))}, - {K, const_cast(gpu_ptr(k))}, - {V, const_cast(gpu_ptr(v))}, + {Q, gpu_ptr(q)}, + {K, gpu_ptr(k)}, + {V, gpu_ptr(v)}, {SCALE, &scale}, {O, gpu_ptr(o)}}; if (mask_arr) { - variant_pack[BIAS] = const_cast(gpu_ptr(*mask_arr)); + variant_pack[BIAS] = gpu_ptr(*mask_arr); } if (output_logsumexp) { variant_pack[STATS] = gpu_ptr(stats); } - execute_graph(encoder, handle, graph, variant_pack); + CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack))); } void sdpa_backward_cudnn( @@ -480,21 +357,21 @@ void sdpa_backward_cudnn( auto& graph = it->second; std::unordered_map variant_pack{ - {Q, const_cast(gpu_ptr(q))}, - {K, const_cast(gpu_ptr(k))}, - {V, const_cast(gpu_ptr(v))}, + {Q, gpu_ptr(q)}, + {K, gpu_ptr(k)}, + {V, gpu_ptr(v)}, {SCALE, &scale}, - {O, const_cast(gpu_ptr(o))}, - {STATS, const_cast(gpu_ptr(stats))}, - {D_O, const_cast(gpu_ptr(d_o))}, + {O, gpu_ptr(o)}, + {STATS, gpu_ptr(stats)}, + {D_O, gpu_ptr(d_o)}, {D_Q, gpu_ptr(d_q)}, {D_K, gpu_ptr(d_k)}, {D_V, gpu_ptr(d_v)}}; if (mask_arr) { - variant_pack[BIAS] = const_cast(gpu_ptr(*mask_arr)); + variant_pack[BIAS] = gpu_ptr(*mask_arr); } - execute_graph(encoder, handle, graph, variant_pack); + CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack))); } // Defined in scaled_dot_product_attention.cu file. diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index eb74ae14a..35078fdf6 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -32,6 +32,13 @@ void check_cuda_error(const char* name, CUresult err) { } } +void check_cudnn_error(const char* name, cudnnStatus_t err) { + if (err != CUDNN_STATUS_SUCCESS) { + throw std::runtime_error( + fmt::format("{} failed: {}.", name, cudnnGetErrorString(err))); + } +} + const char* dtype_to_cuda_type(const Dtype& dtype) { switch (dtype) { case bool_: diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 417f7c8aa..b060880b5 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -31,8 +31,10 @@ inline T* gpu_ptr(array& arr) { arr.offset()); } +// For const array, keep constness in pointer unless it is untyped. template -inline const T* gpu_ptr(const array& arr) { +inline std::conditional_t, void*, const T*> gpu_ptr( + const array& arr) { return gpu_ptr(const_cast(arr)); }
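
For reference, a condensed, illustrative sketch of how the DnnGraph helper added in mlx/backend/cuda/cudnn_utils.h is meant to be used (not part of the patch; `encoder`, `dtype`, `x`, `w`, `y`, `stride`, `padding_lo`, `padding_hi`, and `dilation` are assumed to be a CommandEncoder, dtype, NHWC arrays, and conv settings already prepared by the caller, as in build_conv_graph above):

    // Lifecycle of the new cuDNN frontend wrapper: declare tensors, add the
    // op, prepare(), optionally set notes, build(), then encode into the
    // command buffer.
    DnnGraph graph(encoder.device().cudnn_handle(), dtype, /* compute */ float32);
    auto x_ = graph.tensor_nchw("X", 'x', x); // bind UIDs 'x'/'w' to inputs
    auto w_ = graph.tensor_nchw("W", 'w', w);
    auto opts = fe::graph::Conv_fprop_attributes()
                    .set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
                    .set_stride(stride)
                    .set_pre_padding(padding_lo)
                    .set_post_padding(padding_hi)
                    .set_dilation(dilation);
    auto y_ = graph.conv_fprop(x_, w_, opts);
    graph.tensor_nchw(y_, 'y', y)->set_output(true); // bind UID 'y' to output

    if (graph.prepare().is_good()) {
      CHECK_CUDNN_FE_ERROR(graph.build());
      // encode_capturing() uses stream capture; encode_graph() uses the
      // native CUDA graph API (as in the SDPA path).
      CHECK_CUDNN_FE_ERROR(graph.encode_capturing(
          encoder, {{'x', gpu_ptr(x)}, {'w', gpu_ptr(w)}, {'y', gpu_ptr(y)}}));
    }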