diff --git a/mlx/backend/cuda/quantized/cuda_fp4.h b/mlx/backend/cuda/quantized/cuda_fp4.h index 73943bb36..c107a38c7 100644 --- a/mlx/backend/cuda/quantized/cuda_fp4.h +++ b/mlx/backend/cuda/quantized/cuda_fp4.h @@ -1,5 +1,18 @@ #pragma once +struct __nv_fp8_e8m0 { + __device__ __nv_fp8_e8m0(uint8_t x) : __x(x) {} + + __device__ operator float() { + if (__x == 0xFF) { + return std::numeric_limits::quiet_NaN(); + } + return std::ldexp(1.0f, static_cast(__x) - 127); + } + + uint8_t __x{0}; +}; + struct __nv_fp4_e2m1 { __device__ __nv_fp4_e2m1(float x) { if (std::isnan(x)) {