diff --git a/mlx/backend/cuda/quantized/cuda_fp4.h b/mlx/backend/cuda/quantized/cuda_fp4.h index abda0df70..73943bb36 100644 --- a/mlx/backend/cuda/quantized/cuda_fp4.h +++ b/mlx/backend/cuda/quantized/cuda_fp4.h @@ -3,29 +3,29 @@ struct __nv_fp4_e2m1 { __device__ __nv_fp4_e2m1(float x) { if (std::isnan(x)) { - __x = 0x7; - return; + __x = 0x7; + return; } const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0; x = std::abs(x); if (x > 5.0f) { - __x = 0x7; + __x = 0x7; } else if (x >= 3.5f) { - __x = 0x6; + __x = 0x6; } else if (x > 2.5f) { - __x = 0x5; + __x = 0x5; } else if (x >= 1.75f) { - __x = 0x4; + __x = 0x4; } else if (x > 1.25f) { - __x = 0x3; + __x = 0x3; } else if (x >= 0.75f) { - __x = 0x2; + __x = 0x2; } else if (x > 0.25f) { - __x = 0x1; + __x = 0x1; } else { - __x = 0x0; + __x = 0x0; } __x |= sign_bit; } @@ -47,8 +47,7 @@ struct __nv_fp4_e2m1 { -2.0f, -3.0f, -4.0f, - -6.0f - }; + -6.0f}; return LUT[__x]; } diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 70b125522..bc50ffef9 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -8,6 +8,7 @@ #include #include #include +#include namespace mlx::core { namespace cu {