mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
57 lines
882 B
C++
57 lines
882 B
C++
#pragma once
|
|
|
|
constexpr constant static float FP4_LUT[16] = {
|
|
+0.0f,
|
|
+0.5f,
|
|
+1.0f,
|
|
+1.5f,
|
|
+2.0f,
|
|
+3.0f,
|
|
+4.0f,
|
|
+6.0f,
|
|
-0.0f,
|
|
-0.5f,
|
|
-1.0f,
|
|
-1.5f,
|
|
-2.0f,
|
|
-3.0f,
|
|
-4.0f,
|
|
-6.0f};
|
|
|
|
struct fp4_e2m1 {
|
|
fp4_e2m1(float x) {
|
|
if (metal::isnan(x)) {
|
|
bits = 0x7;
|
|
return;
|
|
}
|
|
|
|
const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
|
|
x = metal::abs(x);
|
|
|
|
if (x > 5.0f) {
|
|
bits = 0x7;
|
|
} else if (x >= 3.5f) {
|
|
bits = 0x6;
|
|
} else if (x > 2.5f) {
|
|
bits = 0x5;
|
|
} else if (x >= 1.75f) {
|
|
bits = 0x4;
|
|
} else if (x > 1.25f) {
|
|
bits = 0x3;
|
|
} else if (x >= 0.75f) {
|
|
bits = 0x2;
|
|
} else if (x > 0.25f) {
|
|
bits = 0x1;
|
|
} else {
|
|
bits = 0x0;
|
|
}
|
|
bits |= sign_bit;
|
|
}
|
|
|
|
operator float() {
|
|
return FP4_LUT[bits];
|
|
}
|
|
|
|
uint8_t bits;
|
|
};
|