diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 9930c75b80..5a5e6182e8 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -5,6 +5,7 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"
 
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
@@ -45,7 +46,7 @@ class MatMul {
     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
 
     auto scale_type = dtype_to_cuda_type(dtype);
-    if (dtype == bfloat16) {
+    if (dtype == bfloat16 || dtype == float16) {
       scale_type = CUDA_R_32F;
     }
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
@@ -192,11 +193,12 @@ class MatMul {
   cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
     switch (dtype) {
       case float16:
-        return CUBLAS_COMPUTE_16F;
+        return CUBLAS_COMPUTE_32F;
       case bfloat16:
         return CUBLAS_COMPUTE_32F;
       case float32:
-        return CUBLAS_COMPUTE_32F;
+        return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
+                                             : CUBLAS_COMPUTE_32F;
       case float64:
       case complex64:
         return CUBLAS_COMPUTE_64F;
diff --git a/mlx/utils.h b/mlx/utils.h
index f0aa7c2de2..f16bf0468d 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -149,6 +149,11 @@ inline bool metal_fast_synch() {
   return metal_fast_synch;
 }
 
+inline bool enable_tf32() {
+  static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
+  return enable_tf32_;
+}
+
 } // namespace env
 
 } // namespace mlx::core