mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix cuda gemm for bf16 (#2288)
This commit is contained in:
parent
6871e2eeb7
commit
a6d780154f
@ -44,9 +44,12 @@ class MatMul {
|
|||||||
int64_t b_batch_stride) {
|
int64_t b_batch_stride) {
|
||||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||||
|
|
||||||
auto type = dtype_to_cuda_type(dtype);
|
auto scale_type = dtype_to_cuda_type(dtype);
|
||||||
|
if (dtype == bfloat16) {
|
||||||
|
scale_type = CUDA_R_32F;
|
||||||
|
}
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
||||||
&matmul_desc_, dtype_to_compute_type(dtype), type));
|
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
|
||||||
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
matmul_desc_,
|
matmul_desc_,
|
||||||
@ -65,6 +68,7 @@ class MatMul {
|
|||||||
&op,
|
&op,
|
||||||
sizeof(cublasOperation_t)));
|
sizeof(cublasOperation_t)));
|
||||||
|
|
||||||
|
auto type = dtype_to_cuda_type(dtype);
|
||||||
a_desc_ = create_matrix_layout(
|
a_desc_ = create_matrix_layout(
|
||||||
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
||||||
b_desc_ = create_matrix_layout(
|
b_desc_ = create_matrix_layout(
|
||||||
@ -187,15 +191,10 @@ class MatMul {
|
|||||||
private:
|
private:
|
||||||
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||||
switch (dtype) {
|
switch (dtype) {
|
||||||
case uint8:
|
|
||||||
case uint16:
|
|
||||||
case int8:
|
|
||||||
case int16:
|
|
||||||
case int32:
|
|
||||||
return CUBLAS_COMPUTE_32I;
|
|
||||||
case float16:
|
case float16:
|
||||||
case bfloat16:
|
|
||||||
return CUBLAS_COMPUTE_16F;
|
return CUBLAS_COMPUTE_16F;
|
||||||
|
case bfloat16:
|
||||||
|
return CUBLAS_COMPUTE_32F;
|
||||||
case float32:
|
case float32:
|
||||||
return CUBLAS_COMPUTE_32F;
|
return CUBLAS_COMPUTE_32F;
|
||||||
case float64:
|
case float64:
|
||||||
@ -209,16 +208,6 @@ class MatMul {
|
|||||||
|
|
||||||
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
|
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
|
||||||
switch (dtype) {
|
switch (dtype) {
|
||||||
case uint8:
|
|
||||||
return CUDA_R_8U;
|
|
||||||
case uint16:
|
|
||||||
return CUDA_R_16U;
|
|
||||||
case int8:
|
|
||||||
return CUDA_R_8I;
|
|
||||||
case int16:
|
|
||||||
return CUDA_R_16I;
|
|
||||||
case int32:
|
|
||||||
return CUDA_R_32I;
|
|
||||||
case float16:
|
case float16:
|
||||||
return CUDA_R_16F;
|
return CUDA_R_16F;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
|
Loading…
Reference in New Issue
Block a user