This commit is contained in:
Awni Hannun
2025-10-23 12:24:56 -07:00
parent 5c051678a3
commit 7b34dc366e
2 changed files with 12 additions and 12 deletions

View File

@@ -3,29 +3,29 @@
struct __nv_fp4_e2m1 {
__device__ __nv_fp4_e2m1(float x) {
if (std::isnan(x)) {
__x = 0x7;
return;
__x = 0x7;
return;
}
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
x = std::abs(x);
if (x > 5.0f) {
__x = 0x7;
__x = 0x7;
} else if (x >= 3.5f) {
__x = 0x6;
__x = 0x6;
} else if (x > 2.5f) {
__x = 0x5;
__x = 0x5;
} else if (x >= 1.75f) {
__x = 0x4;
__x = 0x4;
} else if (x > 1.25f) {
__x = 0x3;
__x = 0x3;
} else if (x >= 0.75f) {
__x = 0x2;
__x = 0x2;
} else if (x > 0.25f) {
__x = 0x1;
__x = 0x1;
} else {
__x = 0x0;
__x = 0x0;
}
__x |= sign_bit;
}
@@ -47,8 +47,7 @@ struct __nv_fp4_e2m1 {
-2.0f,
-3.0f,
-4.0f,
-6.0f
};
-6.0f};
return LUT[__x];
}

View File

@@ -8,6 +8,7 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_fp4.h>
#include <cuda_fp8.h>
namespace mlx::core {
namespace cu {