mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
format (#2700)
This commit is contained in:
@@ -3,29 +3,29 @@
|
||||
struct __nv_fp4_e2m1 {
|
||||
__device__ __nv_fp4_e2m1(float x) {
|
||||
if (std::isnan(x)) {
|
||||
__x = 0x7;
|
||||
return;
|
||||
__x = 0x7;
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
|
||||
x = std::abs(x);
|
||||
|
||||
if (x > 5.0f) {
|
||||
__x = 0x7;
|
||||
__x = 0x7;
|
||||
} else if (x >= 3.5f) {
|
||||
__x = 0x6;
|
||||
__x = 0x6;
|
||||
} else if (x > 2.5f) {
|
||||
__x = 0x5;
|
||||
__x = 0x5;
|
||||
} else if (x >= 1.75f) {
|
||||
__x = 0x4;
|
||||
__x = 0x4;
|
||||
} else if (x > 1.25f) {
|
||||
__x = 0x3;
|
||||
__x = 0x3;
|
||||
} else if (x >= 0.75f) {
|
||||
__x = 0x2;
|
||||
__x = 0x2;
|
||||
} else if (x > 0.25f) {
|
||||
__x = 0x1;
|
||||
__x = 0x1;
|
||||
} else {
|
||||
__x = 0x0;
|
||||
__x = 0x0;
|
||||
}
|
||||
__x |= sign_bit;
|
||||
}
|
||||
@@ -47,8 +47,7 @@ struct __nv_fp4_e2m1 {
|
||||
-2.0f,
|
||||
-3.0f,
|
||||
-4.0f,
|
||||
-6.0f
|
||||
};
|
||||
-6.0f};
|
||||
|
||||
return LUT[__x];
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cuda_fp4.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
namespace mlx::core {
|
||||
namespace cu {
|
||||
|
||||
Reference in New Issue
Block a user