Faster complex matmul (#2571)

This commit is contained in:
Daniel Yeh
2025-10-03 08:33:15 +02:00
committed by GitHub
parent 287c63a093
commit 22a5da76c8
20 changed files with 623 additions and 73 deletions

View File

@@ -50,8 +50,10 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
case complex64:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
@@ -126,12 +128,13 @@ CublasGemm::CublasGemm(
N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cublas_type(dtype);
scale_type_ = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
scale_type_ = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
&matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
@@ -352,6 +355,16 @@ void CublasGemm::execute(
}
}
const void* alpha_ptr = α
const void* beta_ptr = β
complex64_t alpha_c, beta_c;
if (scale_type_ == CUDA_C_32F) {
alpha_c = complex64_t{alpha, 0.0f};
beta_c = complex64_t{beta, 0.0f};
alpha_ptr = &alpha_c;
beta_ptr = &beta_c;
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
@@ -368,12 +381,12 @@ void CublasGemm::execute(
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
&alpha,
alpha_ptr,
b, // a and b are swapped
a_desc_,
a,
b_desc_,
&beta,
beta_ptr,
c ? c : out,
c ? c_desc_ : out_desc_,
out,

View File

@@ -115,6 +115,7 @@ class CublasGemm {
uint64_t M_;
uint64_t N_;
cudaDataType_t scale_type_;
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};

View File

@@ -13,6 +13,37 @@ namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8;
// Accumulator type selection per input element type T.
template <typename T>
struct GemvAccType {
using type = T;
};
template <>
struct GemvAccType<__half> {
using type = float;
};
template <>
struct GemvAccType<__nv_bfloat16> {
using type = float;
};
template <>
struct GemvAccType<float> {
using type = float;
};
template <>
struct GemvAccType<double> {
using type = double;
};
template <>
struct GemvAccType<cu::complex64_t> {
using type = cu::complex64_t;
};
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
@@ -24,7 +55,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) {
float sum = 0.0f;
using Acc = typename GemvAccType<T>::type;
Acc sum = Acc(0);
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat =
@@ -32,12 +64,11 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
}
}
sum = cg::reduce(warp, sum, cg::plus<float>{});
sum = cg::reduce(warp, sum, cg::plus<Acc>{});
if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum);
}
@@ -107,7 +138,7 @@ void gemv(
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat;