mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
Fp8 conversion (#2686)
* add fp8 e4m3 converters * add cuda * default saturate to min/max * fix for older OS * fix no gpu/cpu * fix saturate * fix compile
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
// Required for using M_PI_2 in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
@@ -4030,3 +4029,26 @@ TEST_CASE("test conv_transpose3d with output_padding") {
|
||||
{1, 2, 4, 4, 1});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test fp8 conversion") {
|
||||
for (auto t : {float32, float16, bfloat16}) {
|
||||
array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, t);
|
||||
auto in_fp8 = to_fp8(in);
|
||||
auto out = from_fp8(in_fp8, t);
|
||||
CHECK(array_equal(out, in).item<bool>());
|
||||
}
|
||||
|
||||
array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});
|
||||
array noisy_in({-1.135, -1.01, 0.0001, 1.01, 1.135, 4.6, 447.0});
|
||||
auto in_fp8 = to_fp8(noisy_in);
|
||||
auto out = from_fp8(in_fp8, float32);
|
||||
CHECK(array_equal(out, in).item<bool>());
|
||||
|
||||
// Overflow
|
||||
in = array({-600.0, 600.0});
|
||||
in_fp8 = to_fp8(in);
|
||||
out = from_fp8(in_fp8, float32);
|
||||
|
||||
auto expected = array({-448.0f, 448.0f});
|
||||
CHECK(array_equal(out, expected, true).item<bool>());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user