diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h index c00444c98..e30b86d83 100644 --- a/mlx/backend/cpu/unary_ops.h +++ b/mlx/backend/cpu/unary_ops.h @@ -138,8 +138,8 @@ struct ToFP8 { auto result_high = Simd(f_bits_high >> 20); result = select(f_bits < (121 << 23), result_low, result_high); - auto result_nan = Simd(0x7f); - result = select(f_bits >= fp8_max, result_nan, result); + auto result_sat = Simd(fp8_max); + result = select(f_bits >= fp8_max, result_sat, result); return result | Simd(sign >> 24); } diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index c1887ce0e..c3a89892e 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -458,8 +458,8 @@ struct ToFP8 { uint32_t sign = f_bits & 0x80000000; f_bits ^= sign; if (f_bits >= fp8_max) { - // NaN - all exponent and mantissa bits set to 1 - result = 0x7f; + // Default behavior saturates to min/max + result = fp8_max; } else { if (f_bits < (121 << 23)) { f_bits = diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 2b787e63a..3990b12d7 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -4051,6 +4051,6 @@ TEST_CASE("test fp8 conversion") { in_fp8 = to_fp8(in); out = from_fp8(in_fp8, float32); - auto expected = array({NAN, NAN}); + auto expected = array({-448.0f, 448.0f}); CHECK(array_equal(out, expected, true).item()); }