From a6d780154f2fe79e893045659d17fbace243802a Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 13 Jun 2025 22:10:46 -0700
Subject: [PATCH] fix cuda gemm for bf16 (#2288)

---
 mlx/backend/cuda/matmul.cpp | 27 ++++++++-------------------
 1 file changed, 8 insertions(+), 19 deletions(-)

diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 89247fd3e..9930c75b8 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -44,9 +44,12 @@ class MatMul {
       int64_t b_batch_stride) {
     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
 
-    auto type = dtype_to_cuda_type(dtype);
+    auto scale_type = dtype_to_cuda_type(dtype);
+    if (dtype == bfloat16) {
+      scale_type = CUDA_R_32F;
+    }
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
-        &matmul_desc_, dtype_to_compute_type(dtype), type));
+        &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
     int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
         matmul_desc_,
@@ -65,6 +68,7 @@ class MatMul {
         &op,
         sizeof(cublasOperation_t)));
 
+    auto type = dtype_to_cuda_type(dtype);
     a_desc_ = create_matrix_layout(
         type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
     b_desc_ = create_matrix_layout(
@@ -187,15 +191,10 @@ class MatMul {
  private:
   cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
     switch (dtype) {
-      case uint8:
-      case uint16:
-      case int8:
-      case int16:
-      case int32:
-        return CUBLAS_COMPUTE_32I;
       case float16:
-      case bfloat16:
         return CUBLAS_COMPUTE_16F;
+      case bfloat16:
+        return CUBLAS_COMPUTE_32F;
       case float32:
         return CUBLAS_COMPUTE_32F;
       case float64:
@@ -209,16 +208,6 @@ class MatMul {
 
   cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
     switch (dtype) {
-      case uint8:
-        return CUDA_R_8U;
-      case uint16:
-        return CUDA_R_16U;
-      case int8:
-        return CUDA_R_8I;
-      case int16:
-        return CUDA_R_16I;
-      case int32:
-        return CUDA_R_32I;
       case float16:
         return CUDA_R_16F;
       case bfloat16:
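
Note (not part of the patch): the fix keeps the bf16 storage type for the matrix
layouts but routes bf16 GEMMs through CUBLAS_COMPUTE_32F, and since the scale
type has to match that compute precision, the matmul descriptor now uses
CUDA_R_32F (alpha/beta passed as float). Below is a minimal sketch of that
descriptor configuration; the function name and shapes are illustrative, not
taken from matmul.cpp, and error checking is omitted.

#include <cublasLt.h>
#include <cstdint>

// Illustrative sketch: cuBLASLt descriptors for a bf16 GEMM with FP32 compute
// and FP32 scale, as the patch configures them. Default column-major order is
// assumed, so the leading dimension equals the row count.
void make_bf16_gemm_descriptors(
    uint64_t m,
    uint64_t n,
    uint64_t k,
    cublasLtMatmulDesc_t* matmul_desc,
    cublasLtMatrixLayout_t* a_desc,
    cublasLtMatrixLayout_t* b_desc,
    cublasLtMatrixLayout_t* c_desc) {
  // Compute and scale types are FP32 even though the data is bf16.
  cublasLtMatmulDescCreate(matmul_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  // The matrix layouts keep the bf16 storage type.
  cublasLtMatrixLayoutCreate(a_desc, CUDA_R_16BF, m, k, m);  // A: m x k
  cublasLtMatrixLayoutCreate(b_desc, CUDA_R_16BF, k, n, k);  // B: k x n
  cublasLtMatrixLayoutCreate(c_desc, CUDA_R_16BF, m, n, m);  // C: m x n
  // The alpha/beta pointers later passed to cublasLtMatmul() must point to
  // float values to match the CUDA_R_32F scale type.
}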