fix cuda gemm for bf16 (#2288)

This commit is contained in:
Awni Hannun 2025-06-13 22:10:46 -07:00 committed by GitHub
parent 6871e2eeb7
commit a6d780154f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -44,9 +44,12 @@ class MatMul {
int64_t b_batch_stride) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto type = dtype_to_cuda_type(dtype);
auto scale_type = dtype_to_cuda_type(dtype);
if (dtype == bfloat16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), type));
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
@ -65,6 +68,7 @@ class MatMul {
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cuda_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
@ -187,15 +191,10 @@ class MatMul {
private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case uint8:
case uint16:
case int8:
case int16:
case int32:
return CUBLAS_COMPUTE_32I;
case float16:
case bfloat16:
return CUBLAS_COMPUTE_16F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return CUBLAS_COMPUTE_32F;
case float64:
@ -209,16 +208,6 @@ class MatMul {
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
switch (dtype) {
case uint8:
return CUDA_R_8U;
case uint16:
return CUDA_R_16U;
case int8:
return CUDA_R_8I;
case int16:
return CUDA_R_16I;
case int32:
return CUDA_R_32I;
case float16:
return CUDA_R_16F;
case bfloat16: