[CUDA] Add GEMM-based fallback convolution kernels (#2511)

* Add gemm_conv

* Add gemm_grouped_conv
This commit is contained in:
Cheng
2025-08-20 10:06:22 +09:00
committed by GitHub
parent 65d0d40232
commit ac85ddfdb7
8 changed files with 667 additions and 32 deletions

View File

@@ -202,6 +202,25 @@ CublasGemm::~CublasGemm() {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void CublasGemm::set_out(
Dtype dtype,
bool transposed,
uint64_t rows,
uint64_t cols,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype),
rows,
cols,
transposed,
ld,
batch_count,
batch_stride);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,

View File

@@ -44,6 +44,17 @@ class CublasGemm {
~CublasGemm();
// The output's descriptor is inferred from inputs by default, use this method
// for unusual output.
void set_out(
Dtype dtype,
bool transposed,
uint64_t rows,
uint64_t cols,
int64_t ld,
int32_t batch_count,
int64_t batch_stride);
void run(
cu::CommandEncoder& encoder,
array& out,