mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
format (#2700)
This commit is contained in:
@@ -3,29 +3,29 @@
|
|||||||
struct __nv_fp4_e2m1 {
|
struct __nv_fp4_e2m1 {
|
||||||
__device__ __nv_fp4_e2m1(float x) {
|
__device__ __nv_fp4_e2m1(float x) {
|
||||||
if (std::isnan(x)) {
|
if (std::isnan(x)) {
|
||||||
__x = 0x7;
|
__x = 0x7;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
|
const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
|
||||||
x = std::abs(x);
|
x = std::abs(x);
|
||||||
|
|
||||||
if (x > 5.0f) {
|
if (x > 5.0f) {
|
||||||
__x = 0x7;
|
__x = 0x7;
|
||||||
} else if (x >= 3.5f) {
|
} else if (x >= 3.5f) {
|
||||||
__x = 0x6;
|
__x = 0x6;
|
||||||
} else if (x > 2.5f) {
|
} else if (x > 2.5f) {
|
||||||
__x = 0x5;
|
__x = 0x5;
|
||||||
} else if (x >= 1.75f) {
|
} else if (x >= 1.75f) {
|
||||||
__x = 0x4;
|
__x = 0x4;
|
||||||
} else if (x > 1.25f) {
|
} else if (x > 1.25f) {
|
||||||
__x = 0x3;
|
__x = 0x3;
|
||||||
} else if (x >= 0.75f) {
|
} else if (x >= 0.75f) {
|
||||||
__x = 0x2;
|
__x = 0x2;
|
||||||
} else if (x > 0.25f) {
|
} else if (x > 0.25f) {
|
||||||
__x = 0x1;
|
__x = 0x1;
|
||||||
} else {
|
} else {
|
||||||
__x = 0x0;
|
__x = 0x0;
|
||||||
}
|
}
|
||||||
__x |= sign_bit;
|
__x |= sign_bit;
|
||||||
}
|
}
|
||||||
@@ -47,8 +47,7 @@ struct __nv_fp4_e2m1 {
|
|||||||
-2.0f,
|
-2.0f,
|
||||||
-3.0f,
|
-3.0f,
|
||||||
-4.0f,
|
-4.0f,
|
||||||
-6.0f
|
-6.0f};
|
||||||
};
|
|
||||||
|
|
||||||
return LUT[__x];
|
return LUT[__x];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
#include <cooperative_groups/reduce.h>
|
#include <cooperative_groups/reduce.h>
|
||||||
#include <cuda_fp4.h>
|
#include <cuda_fp4.h>
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|||||||
Reference in New Issue
Block a user