mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-21 20:46:46 +08:00)

commit 65d0d40232 (parent cea9369610)

Split cuDNN helpers into a separate header (#2491)

* Add RAII managed CudaGraph class
* Implement forward rms_norm with cuDNN
* Revert back to old rms norm kernel
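
In short, this commit extracts the cuDNN helpers that conv.cpp had accumulated (tensor descriptor builders, engine-config generation, plan finding and execution) into a shared cudnn_utils.h / cudnn_utils.cpp pair, and replaces the raw cudaGraph_t / cudaStream_t / cudaGraphExec_t handles with RAII wrappers derived from a common CudaHandle base. A condensed sketch of the resulting calling convention, using the names introduced in the diffs below (error handling elided; not a verbatim quote of the new code):

    auto plan = find_cudnn_plan_from_op_graph(
        encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
    if (plan && encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
      conv_cache().emplace(
          cache_key, std::make_pair(backend_type, std::move(*plan)));
    }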
mlx/backend/cuda/CMakeLists.txt
@@ -17,6 +17,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
mlx/backend/cuda/conv.cpp
@@ -1,15 +1,11 @@
 // Copyright © 2025 Apple Inc.
 
+#include "mlx/backend/cuda/cudnn_utils.h"
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/config.h"
 #include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/gpu/copy.h"
-#include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
 
-#include <cudnn_frontend.h>
-#include <cudnn_frontend_find_plan.h>
-#include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
 
 #include <cassert>
@@ -18,9 +14,6 @@ namespace mlx::core {
 
 namespace {
 
-// Not all engines support it so can not use this API now.
-#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
-
 // Alias for better readability.
 #define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
 #define CONV_BACKWARD_INPUT \
@@ -52,198 +45,6 @@ auto& conv_cache() {
   return cache;
 }
 
-template <typename T, typename Vec>
-inline SmallVector<T> convert_vector(const Vec& vec) {
-  return SmallVector<T>(vec.begin(), vec.end());
-}
-
-template <typename T, template <typename U> class Vec>
-inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
-  if (vec.size() > MAX_NDIM) {
-    throw std::runtime_error(
-        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
-  }
-  std::array<T, MAX_NDIM> result = {};
-  std::copy_n(vec.begin(), vec.size(), result.begin());
-  return result;
-}
-
-auto nhwc_to_nchw(const array& x) {
-  auto shape = convert_vector<int64_t>(x.shape());
-  shape.insert(shape.begin() + 1, shape.back());
-  shape.erase(shape.end() - 1);
-  auto strides = convert_vector<int64_t>(x.strides());
-  strides.insert(strides.begin() + 1, strides.back());
-  strides.erase(strides.end() - 1);
-  return std::make_tuple(std::move(shape), std::move(strides));
-}
-
-inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
-  switch (dtype) {
-    case int8:
-      return CUDNN_DATA_INT8;
-    case int32:
-      return CUDNN_DATA_INT32;
-    case uint8:
-      return CUDNN_DATA_UINT8;
-    case float16:
-      return CUDNN_DATA_HALF;
-    case bfloat16:
-      return CUDNN_DATA_BFLOAT16;
-    case float32:
-      return CUDNN_DATA_FLOAT;
-    case float64:
-      return CUDNN_DATA_DOUBLE;
-    default:
-      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
-  }
-}
-
-inline uint8_t get_alignment(const array& x) {
-  uint8_t alignment = 1;
-  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
-  for (; alignment < 32; alignment *= 2) {
-    if (address % (alignment * 2)) {
-      return alignment;
-    }
-  }
-  return alignment;
-}
-
-inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
-  auto [shape, strides] = nhwc_to_nchw(x);
-  return cudnn_frontend::TensorBuilder()
-      .setDim(shape.size(), shape.data())
-      .setStrides(strides.size(), strides.data())
-      .setId(id)
-      .setAlignment(get_alignment(x))
-      .setDataType(dtype_to_cudnn_type(x.dtype()))
-      .build();
-}
-
-cudnn_frontend::EngineConfigList get_engine_configs(
-    cudnnBackendDescriptorType_t backend_type,
-    Dtype dtype,
-    cudnn_frontend::OperationGraph& op_graph,
-    bool use_fallback = false) {
-  cudnn_frontend::GeneratorSource source;
-  if (use_fallback) {
-    source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
-      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
-                          .setOperationGraph(op_graph)
-                          .setOperation(backend_type)
-                          .build();
-      return fallback.getFallbackList();
-    };
-  } else {
-    source = [](cudnn_frontend::OperationGraph& op_graph) {
-      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
-                            .setOperationGraph(op_graph)
-                            .setHeurMode(CUDNN_HEUR_MODE_A)
-                            .build();
-      return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
-    };
-  }
-
-  cudnn_frontend::EngineConfigGenerator generator(1, &source);
-  auto configs = generator.generate_engine_config(op_graph);
-
-  cudnn_frontend::EngineConfigList filtered_configs;
-  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
-    if (cudnn_frontend::hasNumericalNote<
-            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
-      return true;
-    }
-    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
-        dtype == float32 && !env::enable_tf32()) {
-      return true;
-    }
-    return false;
-  });
-  return filtered_configs;
-}
-
-bool execute_plan(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    array& x,
-    array& w,
-    array& y) {
-  int workspace_size = plan.getWorkspaceSize();
-  array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
-
-  int64_t uids[3] = {'x', 'w', 'y'};
-  void* data_ptrs[3] = {
-      x.data<void>(),
-      w.data<void>(),
-      y.data<void>(),
-  };
-
-  auto variantPack = cudnn_frontend::VariantPackBuilder()
-                         .setWorkspacePointer(workspace.data<void>())
-                         .setDataPointers(3, data_ptrs)
-                         .setUids(3, uids)
-                         .build();
-
-  auto handle = encoder.device().cudnn_handle();
-  cudnnSetStream(handle, encoder.stream());
-
-#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
-  cudaGraph_t graph;
-  cudaGraphCreate(&graph, 0);
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
-  if (cudnnBackendPopulateCudaGraph(
-          handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
-      CUDNN_STATUS_SUCCESS) {
-    return false;
-  }
-  encoder.add_graph_node(graph);
-#else
-  auto capture = encoder.capture_context();
-  if (cudnnBackendExecute(
-          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
-      CUDNN_STATUS_SUCCESS) {
-    // Discard the captured graph when failed.
-    capture.discard = true;
-    return false;
-  }
-#endif
-
-  encoder.add_temporary(workspace);
-  return true;
-}
-
-bool try_engines(
-    cu::CommandEncoder& encoder,
-    const ConvCacheKey& cache_key,
-    cudnnBackendDescriptorType_t backend_type,
-    cudnn_frontend::EngineConfigList& configs,
-    const std::string& op_graph_tag,
-    array& x,
-    array& w,
-    array& y) {
-  for (auto& config : configs) {
-    try {
-      auto plan = cudnn_frontend::ExecutionPlanBuilder()
-                      .setHandle(encoder.device().cudnn_handle())
-                      .setEngineConfig(config, op_graph_tag)
-                      .build();
-      if (execute_plan(encoder, plan, x, w, y)) {
-        conv_cache().emplace(
-            cache_key, std::make_pair(backend_type, std::move(plan)));
-        return true;
-      }
-    } catch (cudnn_frontend::cudnnException& error) {
-      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
-        throw;
-      }
-    }
-  }
-  return false;
-}
-
 auto get_conv_op_settings(
     cudnnBackendDescriptorType_t backend_type,
     array& x,
@@ -288,7 +89,7 @@ auto get_conv_op_settings(
   }
 }
 
-std::optional<cudnn_frontend::OperationGraph> build_op_graph(
+std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
     cu::CommandEncoder& encoder,
     cudnnBackendDescriptorType_t backend_type,
     Dtype dtype,
@@ -314,9 +115,9 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
           .build();
 
   auto op = cudnn_frontend::OperationBuilder(backend_type)
-                .setxDesc(build_tensor('x', x))
-                .setwDesc(build_tensor('w', w))
-                .setyDesc(build_tensor('y', y))
+                .setxDesc(build_cudnn_tensor_nchw('x', x))
+                .setwDesc(build_cudnn_tensor_nchw('w', w))
+                .setyDesc(build_cudnn_tensor_nchw('y', y))
                 .setcDesc(conv_desc)
                 .build();
 
@@ -478,12 +279,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   ConvCacheKey cache_key{
       encoder.device().cuda_device(),
       dtype_to_cudnn_type(dtype),
-      fixed_vector(in.shape()),
-      fixed_vector(wt.shape()),
-      fixed_vector(kernel_strides_),
-      fixed_vector(padding_lo_),
-      fixed_vector(padding_hi_),
-      fixed_vector(kernel_dilation_),
+      vector_key(in.shape()),
+      vector_key(wt.shape()),
+      vector_key(kernel_strides_),
+      vector_key(padding_lo_),
+      vector_key(padding_hi_),
+      vector_key(kernel_dilation_),
       groups_,
       flip_,
       get_alignment(in),
@@ -495,7 +296,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
     prepare_args(encoder, backend_type, in, wt, out, groups_, s);
     register_args(encoder, backend_type, in, wt, out, out_);
     auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-    if (!execute_plan(encoder, plan, x, w, y)) {
+    if (!encode_cudnn_plan(encoder, plan, {'x', 'w', 'y'}, x, w, y)) {
       throw std::runtime_error("[conv] Cached plan failed to execute.");
     }
     return;
@@ -537,7 +338,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
         padding_hi_,
         kernel_dilation_,
         input_dilation_);
-    op_graph = build_op_graph(
+    op_graph = build_conv_op_graph(
         encoder,
         try_backend,
         dtype,
@@ -560,22 +361,21 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
     throw std::runtime_error("[conv] Can not build op graph.");
   }
 
-  // Get ready to execute the graph.
+  // Setup inputs and outputs.
   register_args(encoder, backend_type, in, wt, out, out_);
 
-  // Try to run plans based on heuristics.
-  auto configs = get_engine_configs(backend_type, dtype, *op_graph);
-  auto tag = op_graph->getTag();
+  // Find a plan for the graph and execute it.
+  auto plan = find_cudnn_plan_from_op_graph(
+      encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
+  if (!plan) {
+    throw std::runtime_error("[conv] Unable to find an execution plan.");
+  }
   auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
-    return;
-  }
-  // Then try fallback plans.
-  configs = get_engine_configs(backend_type, dtype, *op_graph);
-  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
-    return;
-  }
-  throw std::runtime_error("[conv] Unable to find a working engine.");
+  if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
+    throw std::runtime_error("[conv] Failed to run execution plan.");
+  }
+  conv_cache().emplace(
+      cache_key, std::make_pair(backend_type, std::move(*plan)));
 }
 
 } // namespace mlx::core
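
Note the control-flow change in the last hunk: the old code tried heuristic engine configs and then fallback configs, returning on the first one that executed; the new code commits to the first plan that builds (get_cudnn_engine_configs now generates the heuristic and fallback sources together) and treats a failure to encode that plan as a hard error. On later calls the cached plan is reused; a sketch of that fast path, assuming conv_cache() exposes the usual find/end interface (the lookup itself sits outside the hunks shown):

    if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
      auto& [cached_backend, cached_plan] = it->second;
      auto [x, w, y] = dispatch_args(cached_backend, in, wt, out);
      if (!encode_cudnn_plan(encoder, cached_plan, {'x', 'w', 'y'}, x, w, y)) {
        throw std::runtime_error("[conv] Cached plan failed to execute.");
      }
      return;
    }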
mlx/backend/cuda/cudnn_utils.cpp (new file, 252 lines)
@@ -0,0 +1,252 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cudnn_utils.h"
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core {
+
+namespace {
+
+// Create a cudnn tensor descriptor.
+template <typename Vec>
+inline cudnn_frontend::Tensor build_cudnn_tensor(
+    int64_t id,
+    const array& x,
+    const Vec& shape,
+    const Vec& strides) {
+  return cudnn_frontend::TensorBuilder()
+      .setDim(shape.size(), shape.data())
+      .setStrides(strides.size(), strides.data())
+      .setId(id)
+      .setAlignment(get_alignment(x))
+      .setDataType(dtype_to_cudnn_type(x.dtype()))
+      .build();
+}
+
+// Return the shape and strides after transposing from NHWC to NCHW.
+auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
+  assert(shape.size() >= 3);
+  shape.insert(shape.begin() + 1, shape.back());
+  shape.erase(shape.end() - 1);
+  strides.insert(strides.begin() + 1, strides.back());
+  strides.erase(strides.end() - 1);
+  return std::make_tuple(std::move(shape), std::move(strides));
+}
+
+auto nhwc_to_nchw(const array& x) {
+  return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
+}
+
+// Return available engines for a |op_graph|.
+cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph,
+    bool use_fallback = true) {
+  SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
+  sources.push_back([](auto& op_graph) {
+    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+                          .setOperationGraph(op_graph)
+                          .setHeurMode(CUDNN_HEUR_MODE_A)
+                          .build();
+    return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
+  });
+  if (use_fallback) {
+    sources.push_back([&backend_type](auto& op_graph) {
+      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
+                          .setOperationGraph(op_graph)
+                          .setOperation(backend_type)
+                          .build();
+      return fallback.getFallbackList();
+    });
+  }
+
+  auto configs =
+      cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
+          .generate_engine_config(op_graph);
+
+  cudnn_frontend::EngineConfigList filtered_configs;
+  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
+    if (cudnn_frontend::hasNumericalNote<
+            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
+      return true;
+    }
+    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
+        dtype == float32 && !env::enable_tf32()) {
+      return true;
+    }
+    return false;
+  });
+  return filtered_configs;
+}
+
+// Take |engine_configs| and |op_graph| and find a working execution plans
+// from them.
+std::optional<cudnn_frontend::ExecutionPlan>
+find_cudnn_plan_from_engine_configs(
+    cudnnHandle_t handle,
+    const cudnn_frontend::EngineConfigList& engine_configs,
+    const cudnn_frontend::OperationGraph& op_graph) {
+  auto op_graph_tag = op_graph.getTag();
+  for (const auto& config : engine_configs) {
+    try {
+      return cudnn_frontend::ExecutionPlanBuilder()
+          .setHandle(handle)
+          .setEngineConfig(config, op_graph_tag)
+          .build();
+    } catch (cudnn_frontend::cudnnException& error) {
+      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
+        throw;
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+// Prepare workspace and args to execute plan.
+template <typename F>
+bool prepare_cudnn_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs,
+    F&& execute) {
+  int workspace_size = plan.getWorkspaceSize();
+  array workspace(
+      workspace_size > 0 ? allocator::malloc(workspace_size)
+                         : allocator::Buffer(nullptr),
+      {workspace_size},
+      uint8);
+
+  auto args = cudnn_frontend::VariantPackBuilder()
+                  .setWorkspacePointer(workspace.data<void>())
+                  .setDataPointers(num_args, data_ptrs)
+                  .setUids(num_args, uids)
+                  .build();
+
+  auto handle = encoder.device().cudnn_handle();
+  cudnnSetStream(handle, encoder.stream());
+
+  if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
+    return false;
+  }
+
+  encoder.add_temporary(workspace);
+  return true;
+}
+
+} // namespace
+
+cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
+  auto shape = convert_vector<int64_t>(x.shape());
+  return build_cudnn_tensor(id, x, shape, x.strides());
+}
+
+cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
+  auto [shape, strides] = nhwc_to_nchw(x);
+  return build_cudnn_tensor(id, x, shape, strides);
+}
+
+cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
+  if (x.ndim() == 0) {
+    SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
+    return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
+  }
+  if (x.ndim() == 1) {
+    int64_t s = x.shape(0);
+    SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
+    SmallVector<int64_t, 4> strides = {s, 1, s, s};
+    return build_cudnn_tensor(id, x, shape, strides);
+  }
+  if (x.ndim() == 2) {
+    int64_t s = x.strides(0);
+    SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
+    SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
+    return build_cudnn_tensor(id, x, shape, strides);
+  }
+  if (x.ndim() == 3 || x.ndim() == 4) {
+    return build_cudnn_tensor_nchw(id, x);
+  }
+  throw std::runtime_error(
+      fmt::format("Unsupported array with {} dims.", x.ndim()));
+}
+
+cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
+  SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
+  return cudnn_frontend::TensorBuilder()
+      .setDim(scalar_dims.size(), scalar_dims.data())
+      .setStrides(scalar_dims.size(), scalar_dims.data())
+      .setId(id)
+      .setAlignment(16)
+      .setDataType(dtype_to_cudnn_type(dtype))
+      .setByValue(true)
+      .build();
+}
+
+std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
+    cudnnHandle_t handle,
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph) {
+  auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
+  return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
+}
+
+bool encode_cudnn_plan_with_capturing(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs) {
+  return prepare_cudnn_plan(
+      encoder,
+      plan,
+      num_args,
+      uids,
+      data_ptrs,
+      [&](auto handle, auto plan, auto args) {
+        auto capture = encoder.capture_context();
+        if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
+          // Discard the captured graph when failed.
+          capture.discard = true;
+          return false;
+        }
+        return true;
+      });
+}
+
+#if CUDNN_VERSION >= 90500
+bool encode_cudnn_plan_with_graph_api(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    CudaGraph& graph,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs) {
+  return prepare_cudnn_plan(
+      encoder,
+      plan,
+      num_args,
+      uids,
+      data_ptrs,
+      [&](auto handle, auto plan, auto args) {
+        if (!graph) {
+          graph = CudaGraph(encoder.device());
+          if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
+              CUDNN_STATUS_SUCCESS) {
+            return false;
+          }
+        } else {
+          if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
+              CUDNN_STATUS_SUCCESS) {
+            return false;
+          }
+        }
+        encoder.add_graph_node(graph);
+        return true;
+      });
+}
+#endif
+
+} // namespace mlx::core
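
As a concrete illustration of what nhwc_to_nchw computes (example values mine, not from the commit): the helper moves the trailing channel axis into position 1, so the same memory is described in NCHW order without any copy.

    // A contiguous NHWC float array, shape {N, H, W, C} = {2, 4, 4, 8},
    // has row-major strides {128, 32, 8, 1}. nhwc_to_nchw re-describes it as
    //   shape   {2, 8, 4, 4}    (NCHW)
    //   strides {128, 1, 32, 8}
    // so the cuDNN descriptor sees NCHW dims while the data stays NHWC.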
mlx/backend/cuda/cudnn_utils.h (new file, 164 lines)
@@ -0,0 +1,164 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device/config.h"
+#include "mlx/backend/cuda/utils.h"
+#include "mlx/dtype_utils.h"
+
+#include <cudnn_frontend.h>
+#include <cudnn_frontend_find_plan.h>
+#include <fmt/format.h>
+
+#include <algorithm>
+#include <array>
+
+namespace mlx::core {
+
+namespace cu {
+class CommandEncoder;
+}
+
+// Return pointer alignment of |x|'s data.
+inline uint8_t get_alignment(const array& x) {
+  uint8_t alignment = 1;
+  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
+  for (; alignment < 32; alignment *= 2) {
+    if (address % (alignment * 2)) {
+      return alignment;
+    }
+  }
+  return alignment;
+}
+
+// Convert the type of elements in |vec| to |T|.
+template <typename T, typename Vec>
+inline SmallVector<T> convert_vector(const Vec& vec) {
+  return SmallVector<T>(vec.begin(), vec.end());
+}
+
+// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
+//
+// There are 2 differences from the const_param util from kernel_utils.cuh:
+// 1. The rest of array is filled with 0.
+// 2. This util can be used in .cpp files.
+template <typename T, template <typename U> class Vec>
+inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
+  if (vec.size() > MAX_NDIM) {
+    throw std::runtime_error(
+        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
+  }
+  std::array<T, MAX_NDIM> result = {};
+  std::copy_n(vec.begin(), vec.size(), result.begin());
+  return result;
+}
+
+// Helpers used by get_data_ptrs to get pointers.
+inline void* get_data_ptr(const array& arr) {
+  return const_cast<void*>(arr.data<void>());
+}
+
+template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
+inline void* get_data_ptr(T& scalar) {
+  return &scalar;
+}
+
+// Return an array filled with data pointers of args.
+template <typename... Args>
+inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
+  return {get_data_ptr(args)...};
+}
+
+// Map dtype to cudnn data type.
+inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
+  switch (dtype) {
+    case int8:
+      return CUDNN_DATA_INT8;
+    case int32:
+      return CUDNN_DATA_INT32;
+    case uint8:
+      return CUDNN_DATA_UINT8;
+    case float16:
+      return CUDNN_DATA_HALF;
+    case bfloat16:
+      return CUDNN_DATA_BFLOAT16;
+    case float32:
+      return CUDNN_DATA_FLOAT;
+    case float64:
+      return CUDNN_DATA_DOUBLE;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
+  }
+}
+
+// Create a tensor descriptor from |x|.
+cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
+
+// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
+cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
+
+// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
+// from NHWC to NCHW.
+cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
+
+// Create a 4D scalar tensor descriptor, which is passed by value.
+cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
+
+// Find a working plan for |op_graph|.
+std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
+    cudnnHandle_t handle,
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph);
+
+// Encode the plan to command buffer by capturing.
+bool encode_cudnn_plan_with_capturing(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs);
+
+#if CUDNN_VERSION >= 90500
+// Encode the plan to command buffer by using native graph api of cudnn. If the
+// |graph| is empty it will be populated, otherwise it will be updated.
+bool encode_cudnn_plan_with_graph_api(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    CudaGraph& graph,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs);
+#endif
+
+// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
+template <typename... Args>
+bool encode_cudnn_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    std::initializer_list<int64_t> uids,
+    Args&... args) {
+  assert(uids.size() == sizeof...(args));
+  auto data_ptrs = get_data_ptrs(args...);
+  return encode_cudnn_plan_with_capturing(
+      encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
+}
+
+#if CUDNN_VERSION >= 90500
+template <typename... Args>
+bool encode_cudnn_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    CudaGraph& graph,
+    std::initializer_list<int64_t> uids,
+    Args&... args) {
+  assert(uids.size() == sizeof...(args));
+  auto data_ptrs = get_data_ptrs(args...);
+  return encode_cudnn_plan_with_graph_api(
+      encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
+}
+#endif
+
+} // namespace mlx::core
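
The variadic encode_cudnn_plan overloads pair each uid in the initializer list with the data pointer of the matching trailing argument; since get_data_ptr also accepts scalars by address, by-value scalar tensors (see build_cudnn_scalar_4d) can be passed inline. A hedged usage sketch, with hypothetical arrays x, w, y and a hypothetical scalar uid 'a' (not taken from this commit):

    float alpha = 1.0f;  // matched with a build_cudnn_scalar_4d('a', float32) descriptor
    encode_cudnn_plan(encoder, plan, {'x', 'w', 'y', 'a'}, x, w, y, alpha);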
mlx/backend/cuda/device.cpp
@@ -91,9 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
-  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+  graph.end_capture(enc.stream());
   if (discard) {
     return;
   }
@@ -185,9 +183,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 }
 
 CommandEncoder::CommandEncoder(Device& d)
-    : device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
-  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
-}
+    : device_(d),
+      stream_(d),
+      graph_(d),
+      graph_cache_(cuda_graph_cache_size()) {}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
@@ -311,8 +310,7 @@ void CommandEncoder::commit() {
   to_nodes_.clear();
   graph_key_.clear();
   node_map_.clear();
-  CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
-  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+  graph_ = CudaGraph(device_);
 }
 
 // Put completion handlers in a batch.
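
The one-liner graph_ = CudaGraph(device_) does the same work as the removed destroy/create pair: CudaHandle's move assignment (added in the utils.h hunk further below) resets the old handle before adopting the new one.

    // From the CudaHandle base class added below, with explanatory comments:
    CudaHandle& operator=(CudaHandle&& other) {
      reset();                            // cudaGraphDestroy on the committed graph
      std::swap(handle_, other.handle_);  // adopt the freshly created graph
      return *this;
    }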
mlx/backend/cuda/device.h
@@ -21,7 +21,7 @@ class CommandEncoder {
   struct CaptureContext {
     CaptureContext(CommandEncoder& enc);
     ~CaptureContext();
-    cudaGraph_t graph;
+    CudaGraph graph;
     CommandEncoder& enc;
     bool discard{false};
   };
@@ -115,7 +115,7 @@ class CommandEncoder {
 
   Device& device_;
   CudaStream stream_;
-  cudaGraph_t graph_;
+  CudaGraph graph_;
   Worker worker_;
   char node_count_{0};
   char graph_node_count_{0};
mlx/backend/cuda/utils.cpp
@@ -8,36 +8,6 @@
 
 namespace mlx::core {
 
-CudaStream::CudaStream(cu::Device& device) {
-  device.make_current();
-  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
-}
-
-CudaStream::~CudaStream() {
-  CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
-}
-
-CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
-
-CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
-  other.handle_ = nullptr;
-};
-
-CudaGraphExec::~CudaGraphExec() {
-  reset();
-}
-
-void CudaGraphExec::instantiate(cudaGraph_t graph) {
-  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
-}
-
-void CudaGraphExec::reset() {
-  if (handle_ != nullptr) {
-    CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
-    handle_ = nullptr;
-  }
-}
-
 void check_cublas_error(const char* name, cublasStatus_t err) {
   if (err != CUBLAS_STATUS_SUCCESS) {
     // TODO: Use cublasGetStatusString when it is widely available.
@@ -96,4 +66,24 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
   }
 }
 
+CudaGraph::CudaGraph(cu::Device& device) {
+  device.make_current();
+  CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0));
+}
+
+void CudaGraph::end_capture(cudaStream_t stream) {
+  assert(handle_ == nullptr);
+  CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
+}
+
+void CudaGraphExec::instantiate(cudaGraph_t graph) {
+  assert(handle_ == nullptr);
+  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
+}
+
+CudaStream::CudaStream(cu::Device& device) {
+  device.make_current();
+  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking));
+}
+
 } // namespace mlx::core
mlx/backend/cuda/utils.h
@@ -16,44 +16,6 @@ class Device;
 
 struct Dtype;
 
-// Cuda stream managed with RAII.
-class CudaStream {
- public:
-  explicit CudaStream(cu::Device& device);
-  ~CudaStream();
-
-  CudaStream(const CudaStream&) = delete;
-  CudaStream& operator=(const CudaStream&) = delete;
-
-  operator cudaStream_t() const {
-    return stream_;
-  }
-
- private:
-  cudaStream_t stream_;
-};
-
-// Move-able RAII handle of cudaGraphExec_t.
-class CudaGraphExec {
- public:
-  CudaGraphExec(cudaGraphExec_t handle = nullptr);
-  CudaGraphExec(CudaGraphExec&& other);
-  ~CudaGraphExec();
-
-  CudaGraphExec(const CudaGraphExec&) = delete;
-  CudaGraphExec& operator=(const CudaGraphExec&) = delete;
-
-  void instantiate(cudaGraph_t graph);
-  void reset();
-
-  operator cudaGraphExec_t() const {
-    return handle_;
-  }
-
- private:
-  cudaGraphExec_t handle_;
-};
-
 // Throw exception if the cuda API does not succeed.
 void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
@@ -66,4 +28,62 @@ void check_cuda_error(const char* name, CUresult err);
 // Convert Dtype to CUDA C++ types.
 const char* dtype_to_cuda_type(const Dtype& dtype);
 
+// Base class for RAII managed CUDA resources.
+template <typename Handle, cudaError_t (*Destroy)(Handle)>
+class CudaHandle {
+ public:
+  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
+
+  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
+    assert(this != &other);
+    other.handle_ = nullptr;
+  }
+
+  ~CudaHandle() {
+    reset();
+  }
+
+  CudaHandle(const CudaHandle&) = delete;
+  CudaHandle& operator=(const CudaHandle&) = delete;
+
+  CudaHandle& operator=(CudaHandle&& other) {
+    assert(this != &other);
+    reset();
+    std::swap(handle_, other.handle_);
+    return *this;
+  }
+
+  void reset() {
+    if (handle_ != nullptr) {
+      CHECK_CUDA_ERROR(Destroy(handle_));
+      handle_ = nullptr;
+    }
+  }
+
+  operator Handle() const {
+    return handle_;
+  }
+
+ protected:
+  Handle handle_;
+};
+
+// Wrappers of CUDA resources.
+class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaGraph(cu::Device& device);
+  void end_capture(cudaStream_t stream);
+};
+
+class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
+ public:
+  void instantiate(cudaGraph_t graph);
+};
+
+class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
+ public:
+  explicit CudaStream(cu::Device& device);
+};
+
 } // namespace mlx::core
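
Because destruction is a template parameter, deriving a new RAII wrapper from CudaHandle takes only a few lines; for instance, a hypothetical event wrapper (not part of this commit) could read:

    // Hypothetical, for illustration only; cudaEventDestroy matches the
    // cudaError_t (*)(cudaEvent_t) destroy signature the base expects.
    class CudaEvent : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
     public:
      using CudaHandle::CudaHandle;
    };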