Fp8 conversion (#2686)

* add fp8 e4m3 converters

* add cuda

* default saturate to min/max

* fix for older OS

* fix no gpu/cpu

* fix saturate

* fix compile
This commit is contained in:
Awni Hannun
2025-10-27 16:35:50 -07:00
committed by GitHub
parent d1e06117e8
commit 969924cc69
23 changed files with 363 additions and 117 deletions

View File

@@ -2,7 +2,6 @@
// Required for using M_PI_2 in MSVC.
#define _USE_MATH_DEFINES
#include <cmath>
#include <numeric>
@@ -4030,3 +4029,26 @@ TEST_CASE("test conv_transpose3d with output_padding") {
{1, 2, 4, 4, 1});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test fp8 conversion") {
for (auto t : {float32, float16, bfloat16}) {
array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, t);
auto in_fp8 = to_fp8(in);
auto out = from_fp8(in_fp8, t);
CHECK(array_equal(out, in).item<bool>());
}
array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});
array noisy_in({-1.135, -1.01, 0.0001, 1.01, 1.135, 4.6, 447.0});
auto in_fp8 = to_fp8(noisy_in);
auto out = from_fp8(in_fp8, float32);
CHECK(array_equal(out, in).item<bool>());
// Overflow
in = array({-600.0, 600.0});
in_fp8 = to_fp8(in);
out = from_fp8(in_fp8, float32);
auto expected = array({-448.0f, 448.0f});
CHECK(array_equal(out, expected, true).item<bool>());
}