mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Add GEMM-based fallback convolution kernels (#2511)
* Add gemm_conv * Add gemm_grouped_conv
This commit is contained in:
@@ -202,6 +202,25 @@ CublasGemm::~CublasGemm() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||
}
|
||||
|
||||
void CublasGemm::set_out(
|
||||
Dtype dtype,
|
||||
bool transposed,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
int64_t ld,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||
out_desc_ = create_matrix_layout(
|
||||
dtype_to_cublas_type(dtype),
|
||||
rows,
|
||||
cols,
|
||||
transposed,
|
||||
ld,
|
||||
batch_count,
|
||||
batch_stride);
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
|
||||
@@ -44,6 +44,17 @@ class CublasGemm {
|
||||
|
||||
~CublasGemm();
|
||||
|
||||
// The output's descriptor is inferred from inputs by default, use this method
|
||||
// for unusual output.
|
||||
void set_out(
|
||||
Dtype dtype,
|
||||
bool transposed,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
int64_t ld,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride);
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
|
||||
Reference in New Issue
Block a user