mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-30 22:51:24 +08:00
fp16 matmul fix + tf32 env var
This commit is contained in:
parent
c353af5998
commit
3110982b0e
@ -5,6 +5,7 @@
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <fmt/format.h>
|
||||
@ -45,7 +46,7 @@ class MatMul {
|
||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||
|
||||
auto scale_type = dtype_to_cuda_type(dtype);
|
||||
if (dtype == bfloat16) {
|
||||
if (dtype == bfloat16 || dtype == float16) {
|
||||
scale_type = CUDA_R_32F;
|
||||
}
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
||||
@ -192,11 +193,12 @@ class MatMul {
|
||||
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case float16:
|
||||
return CUBLAS_COMPUTE_16F;
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case bfloat16:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case float32:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
|
||||
: CUBLAS_COMPUTE_32F;
|
||||
case float64:
|
||||
case complex64:
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
|
@ -149,6 +149,11 @@ inline bool metal_fast_synch() {
|
||||
return metal_fast_synch;
|
||||
}
|
||||
|
||||
inline bool enable_tf32() {
|
||||
static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
|
||||
return enable_tf32_;
|
||||
}
|
||||
|
||||
} // namespace env
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user