Compare commits

...

2 Commits

Author SHA1 Message Date
Cheng
ac85ddfdb7 [CUDA] Add GEMM-based fallback convolution kernels (#2511)
* Add gemm_conv

* Add gemm_grouped_conv
2025-08-20 10:06:22 +09:00
Cheng
65d0d40232 Split cuDNN helpers into a separate header (#2491)
* Add RAII managed CudaGraph class

* Implement forward rms_norm with cuDNN

* Revert back to old rms norm kernel
2025-08-20 09:29:28 +09:00
14 changed files with 1182 additions and 322 deletions

View File

@@ -16,7 +16,10 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu

View File

@@ -1,15 +1,12 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
@@ -18,9 +15,6 @@ namespace mlx::core {
namespace {
// Not all engines support it so can not use this API now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
@@ -28,6 +22,9 @@ namespace {
#define CONV_BACKWARD_WEIGHT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
// Custom placeholder representing fallback kernel.
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
struct ConvCacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
@@ -47,203 +44,13 @@ struct ConvCacheKey {
auto& conv_cache() {
static LRUBytesKeyCache<
ConvCacheKey,
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
std::pair<
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
cache(/* capacity */ 128);
return cache;
}
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
}
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
auto strides = convert_vector<int64_t>(x.strides());
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(std::move(shape), std::move(strides));
}
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
cudnn_frontend::EngineConfigList get_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = false) {
cudnn_frontend::GeneratorSource source;
if (use_fallback) {
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
};
} else {
source = [](cudnn_frontend::OperationGraph& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
};
}
cudnn_frontend::EngineConfigGenerator generator(1, &source);
auto configs = generator.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
bool execute_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
array& x,
array& w,
array& y) {
int workspace_size = plan.getWorkspaceSize();
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
int64_t uids[3] = {'x', 'w', 'y'};
void* data_ptrs[3] = {
x.data<void>(),
w.data<void>(),
y.data<void>(),
};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
cudaGraph_t graph;
cudaGraphCreate(&graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
if (cudnnBackendPopulateCudaGraph(
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
encoder.add_graph_node(graph);
#else
auto capture = encoder.capture_context();
if (cudnnBackendExecute(
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
#endif
encoder.add_temporary(workspace);
return true;
}
bool try_engines(
cu::CommandEncoder& encoder,
const ConvCacheKey& cache_key,
cudnnBackendDescriptorType_t backend_type,
cudnn_frontend::EngineConfigList& configs,
const std::string& op_graph_tag,
array& x,
array& w,
array& y) {
for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(encoder.device().cudnn_handle())
.setEngineConfig(config, op_graph_tag)
.build();
if (execute_plan(encoder, plan, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(plan)));
return true;
}
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return false;
}
auto get_conv_op_settings(
cudnnBackendDescriptorType_t backend_type,
array& x,
@@ -288,7 +95,7 @@ auto get_conv_op_settings(
}
}
std::optional<cudnn_frontend::OperationGraph> build_op_graph(
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
@@ -314,9 +121,9 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
.build();
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_tensor('x', x))
.setwDesc(build_tensor('w', w))
.setyDesc(build_tensor('y', y))
.setxDesc(build_cudnn_tensor_nchw('x', x))
.setwDesc(build_cudnn_tensor_nchw('w', w))
.setyDesc(build_cudnn_tensor_nchw('y', y))
.setcDesc(conv_desc)
.build();
@@ -478,12 +285,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
ConvCacheKey cache_key{
encoder.device().cuda_device(),
dtype_to_cudnn_type(dtype),
fixed_vector(in.shape()),
fixed_vector(wt.shape()),
fixed_vector(kernel_strides_),
fixed_vector(padding_lo_),
fixed_vector(padding_hi_),
fixed_vector(kernel_dilation_),
vector_key(in.shape()),
vector_key(wt.shape()),
vector_key(kernel_strides_),
vector_key(padding_lo_),
vector_key(padding_hi_),
vector_key(kernel_dilation_),
groups_,
flip_,
get_alignment(in),
@@ -491,12 +298,29 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second;
std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!execute_plan(encoder, plan, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
if (plan) {
// Run cached plan.
std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
}
} else {
// Run fallback kernel.
gemm_conv(
encoder,
in,
wt,
out,
kernel_strides_,
padding_lo_,
kernel_dilation_,
input_dilation_,
groups_,
flip_,
s);
}
return;
}
@@ -537,7 +361,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_hi_,
kernel_dilation_,
input_dilation_);
op_graph = build_op_graph(
op_graph = build_conv_op_graph(
encoder,
try_backend,
dtype,
@@ -556,26 +380,39 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
break;
}
}
if (!op_graph) {
throw std::runtime_error("[conv] Can not build op graph.");
if (op_graph) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
// Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (!plan) {
throw std::runtime_error("[conv] Unable to find an execution plan.");
}
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
}
// Get ready to execute the graph.
register_args(encoder, backend_type, in, wt, out, out_);
// Try to run plans based on heuristics.
auto configs = get_engine_configs(backend_type, dtype, *op_graph);
auto tag = op_graph->getTag();
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
return;
}
// Then try fallback plans.
configs = get_engine_configs(backend_type, dtype, *op_graph);
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
return;
}
throw std::runtime_error("[conv] Unable to find a working engine.");
// Use fallback kernel for settings not supported by cuDNN.
gemm_conv(
encoder,
in,
wt,
out,
kernel_strides_,
padding_lo_,
kernel_dilation_,
input_dilation_,
groups_,
flip_,
s);
conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));
}
} // namespace mlx::core

View File

@@ -0,0 +1,126 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
template <int NDIM>
struct ConvParams {
int N; // Batch size
int C; // In channels
int O; // Out channels
int strides[NDIM];
int padding[NDIM];
int kernel_dilation[NDIM];
int input_dilation[NDIM];
int groups;
bool flip;
int in_spatial_dims[NDIM];
int wt_spatial_dims[NDIM];
int out_spatial_dims[NDIM];
int64_t in_strides[NDIM + 2];
ConvParams(
const array& in,
const array& wt,
const array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip)
: N(in.shape(0)),
C(in.shape(-1)),
O(wt.shape(0)),
groups(groups),
flip(flip) {
std::copy_n(strides.begin(), NDIM, this->strides);
std::copy_n(padding.begin(), NDIM, this->padding);
std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
}
};
void gemm_grouped_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s);
void gemm_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
bool flip,
Stream s);
inline void gemm_conv(
cu::CommandEncoder& encoder,
array in,
array wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s) {
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
encoder.add_temporary(in);
}
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
encoder.add_temporary(wt);
}
if (groups == 1) {
gemm_conv(
encoder,
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
flip,
s);
} else {
gemm_grouped_conv(
encoder,
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
groups,
flip,
s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,217 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, int NDIM>
__global__ void naive_unfold_nd(
const T* in,
T* out,
int filter_size,
int out_pixels,
const __grid_constant__ ConvParams<NDIM> params) {
auto block = cg::this_thread_block();
auto tid = block.group_index();
auto lid = block.thread_index();
int index_batch = tid.z / out_pixels; // [0, N)
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
int index_wt_spatial =
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
if (index_wt_spatial >= filter_size / params.C) {
return;
}
in += tid.y; // [0, C)
out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;
bool valid = index_batch < params.N;
// Get the coordinates in input.
int index_in[NDIM] = {};
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int index_out = index_out_spatial % params.out_spatial_dims[i];
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
if (params.flip) {
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
}
int index = index_out * params.strides[i] - params.padding[i] +
index_wt * params.kernel_dilation[i];
int index_max =
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
valid &= (index >= 0) && (index < index_max) &&
(index % params.input_dilation[i] == 0);
index_in[i] = index / params.input_dilation[i];
index_out_spatial /= params.out_spatial_dims[i];
index_wt_spatial /= params.wt_spatial_dims[i];
}
if (valid) {
int in_offset = index_batch * params.in_strides[0];
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
in_offset += index_in[i] * params.in_strides[i + 1];
}
*out = in[in_offset];
} else {
*out = T{0};
}
}
} // namespace cu
template <int NDIM>
array unfold_inputs_nd(
cu::CommandEncoder& encoder,
const array& in,
int mat_M,
int mat_K,
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
filter_size *= params.wt_spatial_dims[i];
}
int out_pixels = 1;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
out_pixels *= params.out_spatial_dims[i];
}
int wt_spatial_size = mat_K / params.C;
dim3 block_dims;
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
dim3 num_blocks;
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
num_blocks.y = params.C;
num_blocks.z = mat_M;
encoder.set_input_array(in);
encoder.set_output_array(unfolded);
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
encoder.add_kernel_node(
cu::naive_unfold_nd<DataType, NDIM>,
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
filter_size,
out_pixels,
params);
});
return unfolded;
}
template <int NDIM>
void gemm_conv_nd(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
ConvParams<NDIM>& params,
Stream s) {
// Get gemm shapes.
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
int mat_N = params.O; // O
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
array in_unfolded =
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
// Reshape weight to (C * H_wt * W_wt, O) for gemm.
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
wt_reshaped.copy_shared_buffer(
wt,
{1, mat_K},
{false, false, /* col_contiguous */ true},
wt.data_size());
// Single batch.
Shape batch_shape{1};
Strides a_batch_strides{0};
Strides b_batch_strides{0};
// Run matmul.
CublasGemm gemm(
encoder.device(),
in.dtype(),
false, // a_transposed
mat_M, // a_rows
mat_K, // a_cols
mat_K, // lda
true, // b_transposed
mat_K, // b_rows
mat_N, // b_cols
mat_K, // ldb
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(
encoder,
out,
in_unfolded,
wt_reshaped,
batch_shape,
a_batch_strides,
b_batch_strides);
}
void gemm_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
bool flip,
Stream s) {
int conv_ndim = in.ndim() - 2;
if (conv_ndim < 1 || conv_ndim > 3) {
throw std::runtime_error(
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
}
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
ConvParams<ndim_constant()> params(
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
1, // groups
flip);
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,231 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, int NDIM>
__global__ void naive_grouped_unfold_transpose_nd(
const T* in,
T* out,
int filter_size,
int out_pixels,
const __grid_constant__ ConvParams<NDIM> params) {
auto block = cg::this_thread_block();
auto tid = block.group_index();
auto lid = block.thread_index();
int index_batch = tid.z / out_pixels; // [0, N)
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
int index_wt_spatial =
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
if (index_wt_spatial >= filter_size / params.C) {
return;
}
in += tid.y; // [0, C)
out += tid.z * filter_size + tid.y * (filter_size / params.C);
bool valid = index_batch < params.N;
// Get the coordinates in input.
int index_in[NDIM] = {};
int wt_stride = 1;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int index_out = index_out_spatial % params.out_spatial_dims[i];
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
out += index_wt * wt_stride;
if (params.flip) {
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
}
int index = index_out * params.strides[i] - params.padding[i] +
index_wt * params.kernel_dilation[i];
int index_max =
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
valid &= (index >= 0) && (index < index_max) &&
(index % params.input_dilation[i] == 0);
index_in[i] = index / params.input_dilation[i];
index_out_spatial /= params.out_spatial_dims[i];
index_wt_spatial /= params.wt_spatial_dims[i];
wt_stride *= params.wt_spatial_dims[i];
}
if (valid) {
int in_offset = index_batch * params.in_strides[0];
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
in_offset += index_in[i] * params.in_strides[i + 1];
}
*out = in[in_offset];
} else {
*out = T{0};
}
}
} // namespace cu
template <int NDIM>
array grouped_unfold_transpose_inputs_nd(
cu::CommandEncoder& encoder,
const array& in,
int mat_M,
int mat_K,
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
filter_size *= params.wt_spatial_dims[i];
}
int out_pixels = 1;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
out_pixels *= params.out_spatial_dims[i];
}
int wt_spatial_size = (mat_K * params.groups) / params.C;
dim3 block_dims;
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
dim3 num_blocks;
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
num_blocks.y = params.C;
num_blocks.z = mat_M;
encoder.set_input_array(in);
encoder.set_output_array(unfolded);
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
encoder.add_kernel_node(
cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
filter_size,
out_pixels,
params);
});
return unfolded;
}
template <int NDIM>
void gemm_grouped_conv_nd(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
ConvParams<NDIM>& params,
Stream s) {
// Get gemm shapes.
int C_per_group = params.C / params.groups;
int O_per_group = params.O / params.groups;
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
int mat_N = O_per_group; // O_per_group
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
encoder, in, mat_M, mat_K, mat_N, params);
// Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
array wt_view(
{params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
array wt_reshaped = contiguous_copy_gpu(wt_view, s);
// Batch with size of groups.
Shape batch_shape{params.groups};
Strides a_batch_strides{mat_K};
Strides b_batch_strides{mat_N * mat_K};
// Run matmul.
CublasGemm gemm(
encoder.device(),
in.dtype(),
false, // a_transposed
mat_M, // a_rows
mat_K, // a_cols
mat_K * params.groups, // lda
true, // b_transposed
mat_K, // b_rows
mat_N, // b_cols
mat_K, // ldb
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.set_out(
out.dtype(),
false, // out_transposed
mat_M, // out_rows
mat_N, // out_cols
mat_N * params.groups, // out_ld
params.groups, // batch_count
mat_N); // batch_stride
gemm.run(
encoder,
out,
in_unfolded,
wt_reshaped,
batch_shape,
a_batch_strides,
b_batch_strides);
}
void gemm_grouped_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s) {
int conv_ndim = in.ndim() - 2;
if (conv_ndim < 1 || conv_ndim > 3) {
throw std::runtime_error(
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
}
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
ConvParams<ndim_constant()> params(
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
groups,
flip);
gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,252 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
namespace {
// Create a cudnn tensor descriptor.
template <typename Vec>
inline cudnn_frontend::Tensor build_cudnn_tensor(
int64_t id,
const array& x,
const Vec& shape,
const Vec& strides) {
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
// Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
assert(shape.size() >= 3);
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(std::move(shape), std::move(strides));
}
auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
}
// Return available engines for a |op_graph|.
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = true) {
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
sources.push_back([](auto& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
});
if (use_fallback) {
sources.push_back([&backend_type](auto& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
});
}
auto configs =
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
// Take |engine_configs| and |op_graph| and find a working execution plans
// from them.
std::optional<cudnn_frontend::ExecutionPlan>
find_cudnn_plan_from_engine_configs(
cudnnHandle_t handle,
const cudnn_frontend::EngineConfigList& engine_configs,
const cudnn_frontend::OperationGraph& op_graph) {
auto op_graph_tag = op_graph.getTag();
for (const auto& config : engine_configs) {
try {
return cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, op_graph_tag)
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return std::nullopt;
}
// Prepare workspace and args to execute plan.
template <typename F>
bool prepare_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
array workspace(
workspace_size > 0 ? allocator::malloc(workspace_size)
: allocator::Buffer(nullptr),
{workspace_size},
uint8);
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
return false;
}
encoder.add_temporary(workspace);
return true;
}
} // namespace
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
return build_cudnn_tensor(id, x, shape, x.strides());
}
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return build_cudnn_tensor(id, x, shape, strides);
}
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
if (x.ndim() == 0) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
}
if (x.ndim() == 1) {
int64_t s = x.shape(0);
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
SmallVector<int64_t, 4> strides = {s, 1, s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
int64_t s = x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 3 || x.ndim() == 4) {
return build_cudnn_tensor_nchw(id, x);
}
throw std::runtime_error(
fmt::format("Unsupported array with {} dims.", x.ndim()));
}
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return cudnn_frontend::TensorBuilder()
.setDim(scalar_dims.size(), scalar_dims.data())
.setStrides(scalar_dims.size(), scalar_dims.data())
.setId(id)
.setAlignment(16)
.setDataType(dtype_to_cudnn_type(dtype))
.setByValue(true)
.build();
}
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
auto capture = encoder.capture_context();
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
return true;
});
}
#if CUDNN_VERSION >= 90500
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
if (!graph) {
graph = CudaGraph(encoder.device());
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
} else {
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
}
encoder.add_graph_node(graph);
return true;
});
}
#endif
} // namespace mlx::core

View File

@@ -0,0 +1,164 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <algorithm>
#include <array>
namespace mlx::core {
namespace cu {
class CommandEncoder;
}
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
// Convert the type of elements in |vec| to |T|.
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
}
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
//
// There are 2 differences from the const_param util from kernel_utils.cuh:
// 1. The rest of array is filled with 0.
// 2. This util can be used in .cpp files.
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
inline void* get_data_ptr(T& scalar) {
return &scalar;
}
// Return an array filled with data pointers of args.
template <typename... Args>
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
return {get_data_ptr(args)...};
}
// Map dtype to cudnn data type.
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
// Create a tensor descriptor from |x|.
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
// from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
// Create a 4D scalar tensor descriptor, which is passed by value.
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
// Find a working plan for |op_graph|.
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph);
// Encode the plan to command buffer by capturing.
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs);
#if CUDNN_VERSION >= 90500
// Encode the plan to command buffer by using native graph api of cudnn. If the
// |graph| is empty it will be populated, otherwise it will be updated.
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs);
#endif
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_capturing(
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
}
#if CUDNN_VERSION >= 90500
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_graph_api(
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
}
#endif
} // namespace mlx::core

View File

@@ -91,9 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
}
CommandEncoder::CaptureContext::~CaptureContext() {
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
graph.end_capture(enc.stream());
if (discard) {
return;
}
@@ -185,9 +183,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
}
CommandEncoder::CommandEncoder(Device& d)
: device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
}
: device_(d),
stream_(d),
graph_(d),
graph_cache_(cuda_graph_cache_size()) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
@@ -311,8 +310,7 @@ void CommandEncoder::commit() {
to_nodes_.clear();
graph_key_.clear();
node_map_.clear();
CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
graph_ = CudaGraph(device_);
}
// Put completion handlers in a batch.

View File

@@ -21,7 +21,7 @@ class CommandEncoder {
struct CaptureContext {
CaptureContext(CommandEncoder& enc);
~CaptureContext();
cudaGraph_t graph;
CudaGraph graph;
CommandEncoder& enc;
bool discard{false};
};
@@ -115,7 +115,7 @@ class CommandEncoder {
Device& device_;
CudaStream stream_;
cudaGraph_t graph_;
CudaGraph graph_;
Worker worker_;
char node_count_{0};
char graph_node_count_{0};

View File

@@ -202,6 +202,25 @@ CublasGemm::~CublasGemm() {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void CublasGemm::set_out(
Dtype dtype,
bool transposed,
uint64_t rows,
uint64_t cols,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype),
rows,
cols,
transposed,
ld,
batch_count,
batch_stride);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,

View File

@@ -44,6 +44,17 @@ class CublasGemm {
~CublasGemm();
// The output's descriptor is inferred from inputs by default, use this method
// for unusual output.
void set_out(
Dtype dtype,
bool transposed,
uint64_t rows,
uint64_t cols,
int64_t ld,
int32_t batch_count,
int64_t batch_stride);
void run(
cu::CommandEncoder& encoder,
array& out,

View File

@@ -8,36 +8,6 @@
namespace mlx::core {
CudaStream::CudaStream(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
}
CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
}
CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
other.handle_ = nullptr;
};
CudaGraphExec::~CudaGraphExec() {
reset();
}
void CudaGraphExec::instantiate(cudaGraph_t graph) {
CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
}
void CudaGraphExec::reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
handle_ = nullptr;
}
}
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
@@ -96,4 +66,24 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
}
}
CudaGraph::CudaGraph(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0));
}
void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
}
void CudaGraphExec::instantiate(cudaGraph_t graph) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
}
CudaStream::CudaStream(cu::Device& device) {
device.make_current();
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking));
}
} // namespace mlx::core

View File

@@ -16,44 +16,6 @@ class Device;
struct Dtype;
// Cuda stream managed with RAII.
class CudaStream {
public:
explicit CudaStream(cu::Device& device);
~CudaStream();
CudaStream(const CudaStream&) = delete;
CudaStream& operator=(const CudaStream&) = delete;
operator cudaStream_t() const {
return stream_;
}
private:
cudaStream_t stream_;
};
// Move-able RAII handle of cudaGraphExec_t.
class CudaGraphExec {
public:
CudaGraphExec(cudaGraphExec_t handle = nullptr);
CudaGraphExec(CudaGraphExec&& other);
~CudaGraphExec();
CudaGraphExec(const CudaGraphExec&) = delete;
CudaGraphExec& operator=(const CudaGraphExec&) = delete;
void instantiate(cudaGraph_t graph);
void reset();
operator cudaGraphExec_t() const {
return handle_;
}
private:
cudaGraphExec_t handle_;
};
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
@@ -66,4 +28,62 @@ void check_cuda_error(const char* name, CUresult err);
// Convert Dtype to CUDA C++ types.
const char* dtype_to_cuda_type(const Dtype& dtype);
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
public:
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
assert(this != &other);
other.handle_ = nullptr;
}
~CudaHandle() {
reset();
}
CudaHandle(const CudaHandle&) = delete;
CudaHandle& operator=(const CudaHandle&) = delete;
CudaHandle& operator=(CudaHandle&& other) {
assert(this != &other);
reset();
std::swap(handle_, other.handle_);
return *this;
}
void reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(Destroy(handle_));
handle_ = nullptr;
}
}
operator Handle() const {
return handle_;
}
protected:
Handle handle_;
};
// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
public:
using CudaHandle::CudaHandle;
explicit CudaGraph(cu::Device& device);
void end_capture(cudaStream_t stream);
};
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
public:
void instantiate(cudaGraph_t graph);
};
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
public:
explicit CudaStream(cu::Device& device);
};
} // namespace mlx::core

View File

@@ -15,14 +15,6 @@ cuda_skip = {
"TestOps.test_hadamard_grad_vmap",
# Convolutions NYI
"TestConv.test_1d_conv_with_2d",
"TestConv.test_conv_1d_groups_flipped",
"TestConv.test_conv_general_flip_grad",
"TestConv.test_torch_conv_2D",
"TestConv.test_torch_conv_depthwise",
"TestConv.test_torch_conv_general",
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
# FFTs NYI
"TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two",