From 65d0d402321f8cacaf0143591d77f4025b44b6b9 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 20 Aug 2025 09:29:28 +0900 Subject: [PATCH] Split cuDNN helpers into a separate header (#2491) * Add RAII managed CudaGraph class * Implement forward rms_norm with cuDNN * Revert back to old rms norm kernel --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/conv.cpp | 248 +++--------------------------- mlx/backend/cuda/cudnn_utils.cpp | 252 +++++++++++++++++++++++++++++++ mlx/backend/cuda/cudnn_utils.h | 164 ++++++++++++++++++++ mlx/backend/cuda/device.cpp | 14 +- mlx/backend/cuda/device.h | 4 +- mlx/backend/cuda/utils.cpp | 50 +++--- mlx/backend/cuda/utils.h | 96 +++++++----- 8 files changed, 527 insertions(+), 302 deletions(-) create mode 100644 mlx/backend/cuda/cudnn_utils.cpp create mode 100644 mlx/backend/cuda/cudnn_utils.h diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 38708be5e..c529af1d2 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -17,6 +17,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 29d4803a8..4c8016e9d 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -1,15 +1,11 @@ // Copyright © 2025 Apple Inc. +#include "mlx/backend/cuda/cudnn_utils.h" #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/gpu/copy.h" -#include "mlx/dtype_utils.h" #include "mlx/primitives.h" -#include -#include -#include #include #include @@ -18,9 +14,6 @@ namespace mlx::core { namespace { -// Not all engines support it so can not use this API now. -#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0 - // Alias for better readability. 
 #define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
 #define CONV_BACKWARD_INPUT \
@@ -52,198 +45,6 @@ auto& conv_cache() {
   return cache;
 }
 
-template <typename T, typename Vec>
-inline SmallVector<T> convert_vector(const Vec& vec) {
-  return SmallVector<T>(vec.begin(), vec.end());
-}
-
-template <typename T, template <typename U> class Vec>
-inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
-  if (vec.size() > MAX_NDIM) {
-    throw std::runtime_error(
-        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
-  }
-  std::array<T, MAX_NDIM> result = {};
-  std::copy_n(vec.begin(), vec.size(), result.begin());
-  return result;
-}
-
-auto nhwc_to_nchw(const array& x) {
-  auto shape = convert_vector<int64_t>(x.shape());
-  shape.insert(shape.begin() + 1, shape.back());
-  shape.erase(shape.end() - 1);
-  auto strides = convert_vector<int64_t>(x.strides());
-  strides.insert(strides.begin() + 1, strides.back());
-  strides.erase(strides.end() - 1);
-  return std::make_tuple(std::move(shape), std::move(strides));
-}
-
-inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
-  switch (dtype) {
-    case int8:
-      return CUDNN_DATA_INT8;
-    case int32:
-      return CUDNN_DATA_INT32;
-    case uint8:
-      return CUDNN_DATA_UINT8;
-    case float16:
-      return CUDNN_DATA_HALF;
-    case bfloat16:
-      return CUDNN_DATA_BFLOAT16;
-    case float32:
-      return CUDNN_DATA_FLOAT;
-    case float64:
-      return CUDNN_DATA_DOUBLE;
-    default:
-      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
-  }
-}
-
-inline uint8_t get_alignment(const array& x) {
-  uint8_t alignment = 1;
-  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
-  for (; alignment < 32; alignment *= 2) {
-    if (address % (alignment * 2)) {
-      return alignment;
-    }
-  }
-  return alignment;
-}
-
-inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
-  auto [shape, strides] = nhwc_to_nchw(x);
-  return cudnn_frontend::TensorBuilder()
-      .setDim(shape.size(), shape.data())
-      .setStrides(strides.size(), strides.data())
-      .setId(id)
-      .setAlignment(get_alignment(x))
-      .setDataType(dtype_to_cudnn_type(x.dtype()))
-      .build();
-}
-
-cudnn_frontend::EngineConfigList get_engine_configs(
-    cudnnBackendDescriptorType_t backend_type,
-    Dtype dtype,
-    cudnn_frontend::OperationGraph& op_graph,
-    bool use_fallback = false) {
-  cudnn_frontend::GeneratorSource source;
-  if (use_fallback) {
-    source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
-      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
-                          .setOperationGraph(op_graph)
-                          .setOperation(backend_type)
-                          .build();
-      return fallback.getFallbackList();
-    };
-  } else {
-    source = [](cudnn_frontend::OperationGraph& op_graph) {
-      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
-                            .setOperationGraph(op_graph)
-                            .setHeurMode(CUDNN_HEUR_MODE_A)
-                            .build();
-      return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
-    };
-  }
-
-  cudnn_frontend::EngineConfigGenerator generator(1, &source);
-  auto configs = generator.generate_engine_config(op_graph);
-
-  cudnn_frontend::EngineConfigList filtered_configs;
-  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
-    if (cudnn_frontend::hasNumericalNote<
-            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
-      return true;
-    }
-    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
-        dtype == float32 && !env::enable_tf32()) {
-      return true;
-    }
-    return false;
-  });
-  return filtered_configs;
-}
-
-bool execute_plan(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    array& x,
-    array& w,
-    array& y) {
-  int workspace_size = plan.getWorkspaceSize();
-  array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
-
-  int64_t uids[3] = {'x', 'w', 'y'};
-  void* data_ptrs[3] = {
-      x.data<void>(),
-      w.data<void>(),
-      y.data<void>(),
-  };
-
-  auto variantPack = cudnn_frontend::VariantPackBuilder()
-                         .setWorkspacePointer(workspace.data<void>())
-                         .setDataPointers(3, data_ptrs)
-                         .setUids(3, uids)
-                         .build();
-
-  auto handle = encoder.device().cudnn_handle();
-  cudnnSetStream(handle, encoder.stream());
-
-#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
-  cudaGraph_t graph;
-  cudaGraphCreate(&graph, 0);
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
-  if (cudnnBackendPopulateCudaGraph(
-          handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
-      CUDNN_STATUS_SUCCESS) {
-    return false;
-  }
-  encoder.add_graph_node(graph);
-#else
-  auto capture = encoder.capture_context();
-  if (cudnnBackendExecute(
-          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
-      CUDNN_STATUS_SUCCESS) {
-    // Discard the captured graph when failed.
-    capture.discard = true;
-    return false;
-  }
-#endif
-
-  encoder.add_temporary(workspace);
-  return true;
-}
-
-bool try_engines(
-    cu::CommandEncoder& encoder,
-    const ConvCacheKey& cache_key,
-    cudnnBackendDescriptorType_t backend_type,
-    cudnn_frontend::EngineConfigList& configs,
-    const std::string& op_graph_tag,
-    array& x,
-    array& w,
-    array& y) {
-  for (auto& config : configs) {
-    try {
-      auto plan = cudnn_frontend::ExecutionPlanBuilder()
-                      .setHandle(encoder.device().cudnn_handle())
-                      .setEngineConfig(config, op_graph_tag)
-                      .build();
-      if (execute_plan(encoder, plan, x, w, y)) {
-        conv_cache().emplace(
-            cache_key, std::make_pair(backend_type, std::move(plan)));
-        return true;
-      }
-    } catch (cudnn_frontend::cudnnException& error) {
-      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
-        throw;
-      }
-    }
-  }
-  return false;
-}
-
 auto get_conv_op_settings(
     cudnnBackendDescriptorType_t backend_type,
     array& x,
@@ -288,7 +89,7 @@ auto get_conv_op_settings(
   }
 }
 
-std::optional<cudnn_frontend::OperationGraph> build_op_graph(
+std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
     cu::CommandEncoder& encoder,
     cudnnBackendDescriptorType_t backend_type,
     Dtype dtype,
@@ -314,9 +115,9 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
           .build();
 
   auto op = cudnn_frontend::OperationBuilder(backend_type)
-                .setxDesc(build_tensor('x', x))
-                .setwDesc(build_tensor('w', w))
-                .setyDesc(build_tensor('y', y))
+                .setxDesc(build_cudnn_tensor_nchw('x', x))
+                .setwDesc(build_cudnn_tensor_nchw('w', w))
+                .setyDesc(build_cudnn_tensor_nchw('y', y))
                 .setcDesc(conv_desc)
                 .build();
 
@@ -478,12 +279,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   ConvCacheKey cache_key{
       encoder.device().cuda_device(),
       dtype_to_cudnn_type(dtype),
-      fixed_vector(in.shape()),
-      fixed_vector(wt.shape()),
-      fixed_vector(kernel_strides_),
-      fixed_vector(padding_lo_),
-      fixed_vector(padding_hi_),
-      fixed_vector(kernel_dilation_),
+      vector_key(in.shape()),
+      vector_key(wt.shape()),
+      vector_key(kernel_strides_),
+      vector_key(padding_lo_),
+      vector_key(padding_hi_),
+      vector_key(kernel_dilation_),
       groups_,
       flip_,
       get_alignment(in),
@@ -495,7 +296,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
     prepare_args(encoder, backend_type, in, wt, out, groups_, s);
     register_args(encoder, backend_type, in, wt, out, out_);
     auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-    if (!execute_plan(encoder, plan, x, w, y)) {
+    if (!encode_cudnn_plan(encoder, plan, {'x', 'w', 'y'}, x, w, y)) {
      throw std::runtime_error("[conv] Cached plan failed to execute.");
    }
    return;
@@ -537,7 +338,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
           padding_hi_,
           kernel_dilation_,
           input_dilation_);
-      op_graph = build_op_graph(
+      op_graph = build_conv_op_graph(
           encoder,
           try_backend,
           dtype,
@@ -560,22 +361,21 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
     throw std::runtime_error("[conv] Can not build op graph.");
   }
 
-  // Get ready to execute the graph.
+  // Set up inputs and outputs.
   register_args(encoder, backend_type, in, wt, out, out_);
 
-  // Try to run plans based on heuristics.
-  auto configs = get_engine_configs(backend_type, dtype, *op_graph);
-  auto tag = op_graph->getTag();
+  // Find a plan for the graph and execute it.
+  auto plan = find_cudnn_plan_from_op_graph(
+      encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
+  if (!plan) {
+    throw std::runtime_error("[conv] Unable to find an execution plan.");
+  }
   auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
-    return;
+  if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
+    throw std::runtime_error("[conv] Failed to run execution plan.");
   }
-  // Then try fallback plans.
-  configs = get_engine_configs(backend_type, dtype, *op_graph);
-  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
-    return;
-  }
-  throw std::runtime_error("[conv] Unable to find a working engine.");
+  conv_cache().emplace(
+      cache_key, std::make_pair(backend_type, std::move(*plan)));
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/cudnn_utils.cpp b/mlx/backend/cuda/cudnn_utils.cpp
new file mode 100644
index 000000000..76bcc5b0b
--- /dev/null
+++ b/mlx/backend/cuda/cudnn_utils.cpp
@@ -0,0 +1,252 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cudnn_utils.h"
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core {
+
+namespace {
+
+// Create a cudnn tensor descriptor.
+template <typename Vec>
+inline cudnn_frontend::Tensor build_cudnn_tensor(
+    int64_t id,
+    const array& x,
+    const Vec& shape,
+    const Vec& strides) {
+  return cudnn_frontend::TensorBuilder()
+      .setDim(shape.size(), shape.data())
+      .setStrides(strides.size(), strides.data())
+      .setId(id)
+      .setAlignment(get_alignment(x))
+      .setDataType(dtype_to_cudnn_type(x.dtype()))
+      .build();
+}
+
+// Return the shape and strides after transposing from NHWC to NCHW.
+auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
+  assert(shape.size() >= 3);
+  shape.insert(shape.begin() + 1, shape.back());
+  shape.erase(shape.end() - 1);
+  strides.insert(strides.begin() + 1, strides.back());
+  strides.erase(strides.end() - 1);
+  return std::make_tuple(std::move(shape), std::move(strides));
+}
+
+auto nhwc_to_nchw(const array& x) {
+  return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
+}
+
+// Return the available engines for an |op_graph|.
+cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph,
+    bool use_fallback = true) {
+  SmallVector<cudnn_frontend::GeneratorSource> sources;
+  sources.push_back([](auto& op_graph) {
+    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+                          .setOperationGraph(op_graph)
+                          .setHeurMode(CUDNN_HEUR_MODE_A)
+                          .build();
+    return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
+  });
+  if (use_fallback) {
+    sources.push_back([&backend_type](auto& op_graph) {
+      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
+                          .setOperationGraph(op_graph)
+                          .setOperation(backend_type)
+                          .build();
+      return fallback.getFallbackList();
+    });
+  }
+
+  auto configs =
+      cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
+          .generate_engine_config(op_graph);
+
+  cudnn_frontend::EngineConfigList filtered_configs;
+  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
+    if (cudnn_frontend::hasNumericalNote<
+            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
+      return true;
+    }
+    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
+        dtype == float32 && !env::enable_tf32()) {
+      return true;
+    }
+    return false;
+  });
+  return filtered_configs;
+}
+
+// Take |engine_configs| and |op_graph| and find a working execution plan
+// from them.
+std::optional<cudnn_frontend::ExecutionPlan>
+find_cudnn_plan_from_engine_configs(
+    cudnnHandle_t handle,
+    const cudnn_frontend::EngineConfigList& engine_configs,
+    const cudnn_frontend::OperationGraph& op_graph) {
+  auto op_graph_tag = op_graph.getTag();
+  for (const auto& config : engine_configs) {
+    try {
+      return cudnn_frontend::ExecutionPlanBuilder()
+                 .setHandle(handle)
+                 .setEngineConfig(config, op_graph_tag)
+                 .build();
+    } catch (cudnn_frontend::cudnnException& error) {
+      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
+        throw;
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+// Prepare the workspace and args needed to execute the plan.
+template <typename F>
+bool prepare_cudnn_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs,
+    F&& execute) {
+  int workspace_size = plan.getWorkspaceSize();
+  array workspace(
+      workspace_size > 0 ? allocator::malloc(workspace_size)
+                         : allocator::Buffer(nullptr),
+      {workspace_size},
+      uint8);
+
+  auto args = cudnn_frontend::VariantPackBuilder()
+                  .setWorkspacePointer(workspace.data<void>())
+                  .setDataPointers(num_args, data_ptrs)
+                  .setUids(num_args, uids)
+                  .build();
+
+  auto handle = encoder.device().cudnn_handle();
+  cudnnSetStream(handle, encoder.stream());
+
+  if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
+    return false;
+  }
+
+  encoder.add_temporary(workspace);
+  return true;
+}
+
+} // namespace
+
+cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
+  auto shape = convert_vector<int64_t>(x.shape());
+  return build_cudnn_tensor(id, x, shape, x.strides());
+}
+
+cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
+  auto [shape, strides] = nhwc_to_nchw(x);
+  return build_cudnn_tensor(id, x, shape, strides);
+}
+
+cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
+  if (x.ndim() == 0) {
+    SmallVector<int64_t> scalar_dims = {1, 1, 1, 1};
+    return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
+  }
+  if (x.ndim() == 1) {
+    int64_t s = x.shape(0);
+    SmallVector<int64_t> shape = {1, x.shape(0), 1, 1};
+    SmallVector<int64_t> strides = {s, 1, s, s};
+    return build_cudnn_tensor(id, x, shape, strides);
+  }
+  if (x.ndim() == 2) {
+    int64_t s = x.strides(0);
+    SmallVector<int64_t> shape = {x.shape(0), x.shape(1), 1, 1};
+    SmallVector<int64_t> strides = {s, x.strides(1), s, s};
+    return build_cudnn_tensor(id, x, shape, strides);
+  }
+  if (x.ndim() == 3 || x.ndim() == 4) {
+    return build_cudnn_tensor_nchw(id, x);
+  }
+  throw std::runtime_error(
+      fmt::format("Unsupported array with {} dims.", x.ndim()));
+}
+
+cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
+  SmallVector<int64_t> scalar_dims = {1, 1, 1, 1};
+  return cudnn_frontend::TensorBuilder()
+      .setDim(scalar_dims.size(), scalar_dims.data())
+      .setStrides(scalar_dims.size(), scalar_dims.data())
+      .setId(id)
+      .setAlignment(16)
+      .setDataType(dtype_to_cudnn_type(dtype))
+      .setByValue(true)
+      .build();
+}
+
+std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
+    cudnnHandle_t handle,
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph) {
+  auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
+  return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
+}
+
+bool encode_cudnn_plan_with_capturing(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs) {
+  return prepare_cudnn_plan(
+      encoder,
+      plan,
+      num_args,
+      uids,
+      data_ptrs,
+      [&](auto handle, auto plan, auto args) {
+        auto capture = encoder.capture_context();
+        if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
+          // Discard the captured graph when execution fails.
+          capture.discard = true;
+          return false;
+        }
+        return true;
+      });
+}
+
+#if CUDNN_VERSION >= 90500
+bool encode_cudnn_plan_with_graph_api(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    CudaGraph& graph,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs) {
+  return prepare_cudnn_plan(
+      encoder,
+      plan,
+      num_args,
+      uids,
+      data_ptrs,
+      [&](auto handle, auto plan, auto args) {
+        if (!graph) {
+          graph = CudaGraph(encoder.device());
+          if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
+              CUDNN_STATUS_SUCCESS) {
+            return false;
+          }
+        } else {
+          if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
+              CUDNN_STATUS_SUCCESS) {
+            return false;
+          }
+        }
+        encoder.add_graph_node(graph);
+        return true;
+      });
+}
+#endif
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/cudnn_utils.h b/mlx/backend/cuda/cudnn_utils.h
new file mode 100644
index 000000000..c35c5cac9
--- /dev/null
+++ b/mlx/backend/cuda/cudnn_utils.h
@@ -0,0 +1,164 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device/config.h"
+#include "mlx/backend/cuda/utils.h"
+#include "mlx/dtype_utils.h"
+
+#include <cudnn.h>
+#include <cudnn_frontend.h>
+#include <fmt/format.h>
+
+#include <array>
+#include <optional>
+
+namespace mlx::core {
+
+namespace cu {
+class CommandEncoder;
+}
+
+// Return pointer alignment of |x|'s data.
+inline uint8_t get_alignment(const array& x) {
+  uint8_t alignment = 1;
+  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
+  for (; alignment < 32; alignment *= 2) {
+    if (address % (alignment * 2)) {
+      return alignment;
+    }
+  }
+  return alignment;
+}
+
+// Convert the type of elements in |vec| to |T|.
+template <typename T, typename Vec>
+inline SmallVector<T> convert_vector(const Vec& vec) {
+  return SmallVector<T>(vec.begin(), vec.end());
+}
+
+// Return an array that can be used as a map key for |vec| with size <= MAX_NDIM.
+//
+// There are 2 differences from the const_param util from kernel_utils.cuh:
+// 1. The rest of the array is filled with 0.
+// 2. This util can be used in .cpp files.
+template <typename T, template <typename U> class Vec>
+inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
+  if (vec.size() > MAX_NDIM) {
+    throw std::runtime_error(
+        fmt::format("ndim cannot be larger than {}.", MAX_NDIM));
+  }
+  std::array<T, MAX_NDIM> result = {};
+  std::copy_n(vec.begin(), vec.size(), result.begin());
+  return result;
+}
+
+// Helpers used by get_data_ptrs to get pointers.
+inline void* get_data_ptr(const array& arr) {
+  return const_cast<void*>(arr.data<void>());
+}
+
+template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
+inline void* get_data_ptr(T& scalar) {
+  return &scalar;
+}
+
+// Return an array filled with data pointers of args.
+template <typename... Args>
+inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
+  return {get_data_ptr(args)...};
+}
+
+// Map dtype to cudnn data type.
+inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
+  switch (dtype) {
+    case int8:
+      return CUDNN_DATA_INT8;
+    case int32:
+      return CUDNN_DATA_INT32;
+    case uint8:
+      return CUDNN_DATA_UINT8;
+    case float16:
+      return CUDNN_DATA_HALF;
+    case bfloat16:
+      return CUDNN_DATA_BFLOAT16;
+    case float32:
+      return CUDNN_DATA_FLOAT;
+    case float64:
+      return CUDNN_DATA_DOUBLE;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
+  }
+}
+
+// Create a tensor descriptor from |x|.
+cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
+
+// Create a tensor descriptor from |x|, transposing from NHWC to NCHW.
+cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
+
+// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
+// from NHWC to NCHW.
+cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
+
+// Create a 4D scalar tensor descriptor, which is passed by value.
+cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
+
+// Find a working plan for |op_graph|.
+std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
+    cudnnHandle_t handle,
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph);
+
+// Encode the plan to the command buffer by capturing.
+bool encode_cudnn_plan_with_capturing(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs);
+
+#if CUDNN_VERSION >= 90500
+// Encode the plan to the command buffer using cuDNN's native CUDA graph API.
+// If |graph| is empty it will be populated, otherwise it will be updated.
+bool encode_cudnn_plan_with_graph_api(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    CudaGraph& graph,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs);
+#endif
+
+// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
+template <typename... Args>
+bool encode_cudnn_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    std::initializer_list<int64_t> uids,
+    Args&... args) {
+  assert(uids.size() == sizeof...(args));
+  auto data_ptrs = get_data_ptrs(args...);
+  return encode_cudnn_plan_with_capturing(
+      encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
+}
+
+#if CUDNN_VERSION >= 90500
+template <typename... Args>
+bool encode_cudnn_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    CudaGraph& graph,
+    std::initializer_list<int64_t> uids,
+    Args&... args) {
+  assert(uids.size() == sizeof...(args));
+  auto data_ptrs = get_data_ptrs(args...);
+  return encode_cudnn_plan_with_graph_api(
+      encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
+}
+#endif
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 96b07502f..371ae020c 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -91,9 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
-  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+  graph.end_capture(enc.stream());
   if (discard) {
     return;
   }
@@ -185,9 +183,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 }
 
 CommandEncoder::CommandEncoder(Device& d)
-    : device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
-  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
-}
+    : device_(d),
+      stream_(d),
+      graph_(d),
+      graph_cache_(cuda_graph_cache_size()) {}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
@@ -311,8 +310,7 @@ void CommandEncoder::commit() {
   to_nodes_.clear();
   graph_key_.clear();
   node_map_.clear();
-  CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
-  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+  graph_ = CudaGraph(device_);
 }
 
 // Put completion handlers in a batch.
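Taken together, the new helpers reduce a cuDNN dispatch to three steps: build an operation graph, find an execution plan for it, and encode the plan into the command buffer. A minimal sketch of a caller, modeled on the Convolution::eval_gpu hunks above; it is not part of this patch, and |encoder|, |op_graph|, and the arrays x/w/y are assumed to already exist:

    // Sketch only: assumes |op_graph| was built with tensor ids 'x'/'w'/'y'.
    auto handle = encoder.device().cudnn_handle();
    auto plan = find_cudnn_plan_from_op_graph(
        handle, CONV_FORWARD, float32, op_graph);
    if (!plan) {
      throw std::runtime_error("No usable cuDNN engine found.");
    }
    // The uids must match the tensor ids used when building |op_graph|;
    // encode_cudnn_plan collects the data pointers and captures the execution.
    if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
      throw std::runtime_error("Failed to encode cuDNN plan.");
    }

On success the caller can cache the plan keyed by shapes and strides (via vector_key) and skip the engine search on the next call, which is exactly what conv_cache() does above.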
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 5eb7fd4c1..7b0ff5629 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -21,7 +21,7 @@ class CommandEncoder { struct CaptureContext { CaptureContext(CommandEncoder& enc); ~CaptureContext(); - cudaGraph_t graph; + CudaGraph graph; CommandEncoder& enc; bool discard{false}; }; @@ -115,7 +115,7 @@ class CommandEncoder { Device& device_; CudaStream stream_; - cudaGraph_t graph_; + CudaGraph graph_; Worker worker_; char node_count_{0}; char graph_node_count_{0}; diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 88940a234..09894d4ca 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -8,36 +8,6 @@ namespace mlx::core { -CudaStream::CudaStream(cu::Device& device) { - device.make_current(); - CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); -} - -CudaStream::~CudaStream() { - CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); -} - -CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {} - -CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) { - other.handle_ = nullptr; -}; - -CudaGraphExec::~CudaGraphExec() { - reset(); -} - -void CudaGraphExec::instantiate(cudaGraph_t graph) { - CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0)); -} - -void CudaGraphExec::reset() { - if (handle_ != nullptr) { - CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_)); - handle_ = nullptr; - } -} - void check_cublas_error(const char* name, cublasStatus_t err) { if (err != CUBLAS_STATUS_SUCCESS) { // TODO: Use cublasGetStatusString when it is widely available. @@ -96,4 +66,24 @@ const char* dtype_to_cuda_type(const Dtype& dtype) { } } +CudaGraph::CudaGraph(cu::Device& device) { + device.make_current(); + CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0)); +} + +void CudaGraph::end_capture(cudaStream_t stream) { + assert(handle_ == nullptr); + CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_)); +} + +void CudaGraphExec::instantiate(cudaGraph_t graph) { + assert(handle_ == nullptr); + CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0)); +} + +CudaStream::CudaStream(cu::Device& device) { + device.make_current(); + CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking)); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 555e15065..e811d5e6c 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -16,44 +16,6 @@ class Device; struct Dtype; -// Cuda stream managed with RAII. -class CudaStream { - public: - explicit CudaStream(cu::Device& device); - ~CudaStream(); - - CudaStream(const CudaStream&) = delete; - CudaStream& operator=(const CudaStream&) = delete; - - operator cudaStream_t() const { - return stream_; - } - - private: - cudaStream_t stream_; -}; - -// Move-able RAII handle of cudaGraphExec_t. -class CudaGraphExec { - public: - CudaGraphExec(cudaGraphExec_t handle = nullptr); - CudaGraphExec(CudaGraphExec&& other); - ~CudaGraphExec(); - - CudaGraphExec(const CudaGraphExec&) = delete; - CudaGraphExec& operator=(const CudaGraphExec&) = delete; - - void instantiate(cudaGraph_t graph); - void reset(); - - operator cudaGraphExec_t() const { - return handle_; - } - - private: - cudaGraphExec_t handle_; -}; - // Throw exception if the cuda API does not succeed. 
 void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
@@ -66,4 +28,62 @@ void check_cuda_error(const char* name, CUresult err);
 // Convert Dtype to CUDA C++ types.
 const char* dtype_to_cuda_type(const Dtype& dtype);
 
+// Base class for RAII managed CUDA resources.
+template <typename Handle, cudaError_t (*Destroy)(Handle)>
+class CudaHandle {
+ public:
+  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
+
+  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
+    assert(this != &other);
+    other.handle_ = nullptr;
+  }
+
+  ~CudaHandle() {
+    reset();
+  }
+
+  CudaHandle(const CudaHandle&) = delete;
+  CudaHandle& operator=(const CudaHandle&) = delete;
+
+  CudaHandle& operator=(CudaHandle&& other) {
+    assert(this != &other);
+    reset();
+    std::swap(handle_, other.handle_);
+    return *this;
+  }
+
+  void reset() {
+    if (handle_ != nullptr) {
+      CHECK_CUDA_ERROR(Destroy(handle_));
+      handle_ = nullptr;
+    }
+  }
+
+  operator Handle() const {
+    return handle_;
+  }
+
+ protected:
+  Handle handle_;
+};
+
+// Wrappers of CUDA resources.
+class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaGraph(cu::Device& device);
+  void end_capture(cudaStream_t stream);
+};
+
+class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
+ public:
+  void instantiate(cudaGraph_t graph);
+};
+
+class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
+ public:
+  explicit CudaStream(cu::Device& device);
+};
+
 } // namespace mlx::core
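The CudaHandle template collapses the three hand-written RAII classes deleted from this header into one base, leaving each wrapper to add only its creation helpers. A minimal sketch of the capture lifecycle these wrappers support; the |device| variable and the work enqueued on the stream are assumptions for illustration:

    // Sketch: capture work into a CudaGraph, then instantiate and launch it.
    CudaStream stream(device); // cudaStreamCreateWithFlags under the hood
    CHECK_CUDA_ERROR(
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
    // ... enqueue kernels or cuDNN work on |stream| here ...
    CudaGraph graph;           // starts empty; filled by end_capture
    graph.end_capture(stream); // cudaStreamEndCapture
    CudaGraphExec exec;
    exec.instantiate(graph);   // cudaGraphInstantiate
    CHECK_CUDA_ERROR(cudaGraphLaunch(exec, stream));
    // exec, graph, and stream run their Destroy functions in reverse order
    // of declaration when they go out of scope.

Because the converting operator returns the raw handle, each wrapper can be passed directly to CUDA and cuDNN APIs, which is how encode_cudnn_plan_with_graph_api hands its CudaGraph to cudnnBackendPopulateCudaGraph above.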