mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
default saturate to min/max
This commit is contained in:
@@ -138,8 +138,8 @@ struct ToFP8 {
|
||||
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
|
||||
result = select(f_bits < (121 << 23), result_low, result_high);
|
||||
|
||||
auto result_nan = Simd<uint8_t, N>(0x7f);
|
||||
result = select(f_bits >= fp8_max, result_nan, result);
|
||||
auto result_sat = Simd<uint8_t, N>(fp8_max);
|
||||
result = select(f_bits >= fp8_max, result_sat, result);
|
||||
return result | Simd<uint8_t, N>(sign >> 24);
|
||||
}
|
||||
|
||||
|
||||
@@ -458,8 +458,8 @@ struct ToFP8 {
|
||||
uint32_t sign = f_bits & 0x80000000;
|
||||
f_bits ^= sign;
|
||||
if (f_bits >= fp8_max) {
|
||||
// NaN - all exponent and mantissa bits set to 1
|
||||
result = 0x7f;
|
||||
// Default behavior saturates to min/max
|
||||
result = fp8_max;
|
||||
} else {
|
||||
if (f_bits < (121 << 23)) {
|
||||
f_bits =
|
||||
|
||||
@@ -4051,6 +4051,6 @@ TEST_CASE("test fp8 conversion") {
|
||||
in_fp8 = to_fp8(in);
|
||||
out = from_fp8(in_fp8, float32);
|
||||
|
||||
auto expected = array({NAN, NAN});
|
||||
auto expected = array({-448.0f, 448.0f});
|
||||
CHECK(array_equal(out, expected, true).item<bool>());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user