fp16 matmul fix + tf32 env var

Awni Hannun 2025-06-14 07:17:04 -07:00
parent c353af5998
commit 3110982b0e
2 changed files with 10 additions and 3 deletions

Changed file 1: the CUDA cuBLASLt matmul backend.

@@ -5,6 +5,7 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"
 
 #include <cublasLt.h>
 #include <fmt/format.h>
@@ -45,7 +46,7 @@ class MatMul {
     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
 
     auto scale_type = dtype_to_cuda_type(dtype);
-    if (dtype == bfloat16) {
+    if (dtype == bfloat16 || dtype == float16) {
       scale_type = CUDA_R_32F;
     }
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
@@ -192,11 +193,12 @@ class MatMul {
  cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
    switch (dtype) {
      case float16:
-        return CUBLAS_COMPUTE_16F;
+        return CUBLAS_COMPUTE_32F;
      case bfloat16:
        return CUBLAS_COMPUTE_32F;
      case float32:
-        return CUBLAS_COMPUTE_32F;
+        return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
+                                             : CUBLAS_COMPUTE_32F;
      case float64:
      case complex64:
        return CUBLAS_COMPUTE_64F;
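
Note on the two hunks above: moving float16 from CUBLAS_COMPUTE_16F to CUBLAS_COMPUTE_32F means fp16 matmuls now accumulate in fp32, and with a 32-bit compute type cuBLASLt expects the alpha/beta scale type to be a 32-bit float as well, which is what the earlier scale_type change provides. A minimal sketch of the resulting descriptor setup, assuming only the cuBLASLt C API (the helper name is illustrative and error checking is omitted):

#include <cublasLt.h>

// With a 32-bit compute type, the alpha/beta scalars are fp32 even when the
// A/B/C matrices are fp16 or bf16, hence CUDA_R_32F as the scale type.
inline cublasLtMatmulDesc_t make_halfprec_desc() {  // hypothetical helper
  cublasLtMatmulDesc_t desc;
  cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  return desc;
}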

Changed file 2: the environment-variable helpers in mlx::core::env (the header added as an include above).

@@ -149,6 +149,11 @@ inline bool metal_fast_synch() {
   return metal_fast_synch;
 }
 
+inline bool enable_tf32() {
+  static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
+  return enable_tf32_;
+}
+
 } // namespace env
 
 } // namespace mlx::core
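
The new enable_tf32() helper follows the pattern of the other env:: toggles: it reads MLX_ENABLE_TF32 once, caches the result, and defaults to 1, so float32 matmuls use TF32 unless the process is started with MLX_ENABLE_TF32=0. The diff does not show get_var itself; a hypothetical sketch of such a helper, assuming it simply falls back to the default when the variable is unset:

#include <cstdlib>

// Hypothetical sketch of the get_var helper used above: read an integer
// environment variable and fall back to the given default when it is unset.
inline int get_var(const char* name, int default_value) {
  if (const char* value = std::getenv(name)) {
    return std::atoi(value);
  }
  return default_value;
}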