mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add fp8 e4m3 converters
This commit is contained in:
@@ -4032,3 +4032,25 @@ TEST_CASE("test conv_transpose3d with output_padding") {
|
||||
{1, 2, 4, 4, 1});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
// Exercises the fp8 converters `to_fp8` / `from_fp8`:
//   1. round trip of reference values for several floating point dtypes,
//   2. rounding of slightly perturbed inputs back onto the reference values,
//   3. behavior for inputs whose magnitude is beyond the largest reference
//      value (448).
TEST_CASE("test fp8 conversion") {
  // Round trip per dtype. The reference values are the same ones the
  // noisy-input check below expects fp8 to round onto, so converting them to
  // fp8 and back should reproduce them exactly.
  for (auto t : {float32, float16, bfloat16}) {
    array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, t);
    auto in_fp8 = to_fp8(in);
    auto out = from_fp8(in_fp8, t);
    // Without this check the loop would only verify that the conversions do
    // not throw, not that they produce the right values.
    CHECK(array_equal(out, in).item<bool>());
  }

  // Inputs that are slightly off the reference values must round to them.
  array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});
  array noisy_in({-1.135, -1.01, 0.0001, 1.01, 1.135, 4.6, 447.0});
  auto in_fp8 = to_fp8(noisy_in);
  auto out = from_fp8(in_fp8, float32);
  CHECK(array_equal(out, in).item<bool>());

  // Overflow
  in = array({-600.0, 600.0});
  in_fp8 = to_fp8(in);
  out = from_fp8(in_fp8, float32);

  // Out-of-range inputs are expected to come back as NaN. The trailing `true`
  // asks array_equal to treat NaNs as equal to each other (equal_nan).
  auto expected = array({NAN, NAN});
  CHECK(array_equal(out, expected, true).item<bool>());
}
|
||||
|
||||
Reference in New Issue
Block a user