Fp8 conversion (#2686)

* add fp8 e4m3 converters

* add cuda

* default saturate to min/max

* fix for older OS

* fix no gpu/cpu

* fix saturate

* fix compile
This commit is contained in:
Awni Hannun
2025-10-27 16:35:50 -07:00
committed by GitHub
parent d1e06117e8
commit 969924cc69
23 changed files with 363 additions and 117 deletions

View File

@@ -88,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# Suppress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()

View File

@@ -1,8 +1,11 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -1102,4 +1105,44 @@ void fast::Quantize::eval_cpu(
});
}
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
}
});
}
} // namespace mlx::core

View File

@@ -1,5 +1,6 @@
#pragma once
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
@@ -200,6 +201,15 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
// Per-lane count of leading zero bits, computed with the NEON vclzq_u32
// intrinsic on two 4-lane halves.
// NOTE(review): the body reinterprets `x` as two consecutive uint32x4_t
// vectors and rebuilds the result with asd::make_uint8, so it appears to
// assume T is a 32-bit unsigned type and N == 8 even though the template is
// unconstrained -- confirm no other instantiation is used.
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
  // First four 32-bit lanes of x.
  auto a = *(uint32x4_t*)(&x);
  // The four 32-bit lanes that follow them in memory.
  auto b = *((uint32x4_t*)(&x) + 1);
  a = vclzq_u32(a);
  b = vclzq_u32(b);
  return asd::make_uint8(a, b);
}
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);

View File

@@ -171,6 +171,11 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
// Count the leading zero bits of a scalar (width-1) Simd value.
//
// `__builtin_clz` has an undefined result when its argument is 0, so zero is
// handled explicitly and reported as the full bit width of `unsigned int`,
// the type the builtin operates on.
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
  if (x_.value == 0) {
    return static_cast<T>(sizeof(unsigned int) * 8);
  }
  return __builtin_clz(x_.value);
}
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;

View File

@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
constexpr int N = simd::max_size<T>;
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
while (size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
size -= N;
src += N;
dst += N;

View File

@@ -108,4 +108,73 @@ struct Square {
SINGLE()
};
// Reinterpret N 32-bit integer lanes as N float lanes; no bits are changed.
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
  auto* as_float = reinterpret_cast<Simd<float, N>*>(&x);
  return *as_float;
}
// Reinterpret N float lanes as N 32-bit integer lanes; no bits are changed.
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
  auto* as_bits = reinterpret_cast<Simd<uint32_t, N>*>(&x);
  return *as_bits;
}
// Encodes floating point values as E4M3 float8 bit patterns (one uint8_t per
// element) by working on the float32 bit representation. Both candidate
// encodings are computed for every lane and the right one is selected
// afterwards, so the vector path has no branches.
struct ToFP8 {
  // Vector path: encodes N lanes at once.
  template <typename T, int N>
  Simd<uint8_t, N> operator()(Simd<T, N> f) {
    // 0x43F00000, the float32 bit pattern of 480.0. Magnitudes whose bits are
    // at or above this value are replaced by 0x7E below.
    uint32_t fp8_max = 1087 << 20;
    // 0x46800000, the float32 bit pattern of 2^14. Adding this float to a
    // small magnitude leaves the rounded low-range encoding in the low bits
    // of the sum.
    auto denorm_mask = Simd<uint32_t, N>(141 << 23);
    Simd<uint32_t, N> f_bits;
    // Convert the input lanes to float32 before looking at their bits.
    Simd<float, N> f32 = f;
    f_bits = fp32_to_bits(f32);

    Simd<uint8_t, N> result = 0u;

    // Split off the sign bit; everything below works on the magnitude.
    auto sign = f_bits & 0x80000000;
    f_bits = f_bits ^ sign;

    // Low-range candidate, used for magnitudes below 2^-6 (121 << 23).
    auto f_bits_low =
        fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
    auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);

    // Normal-range candidate: rebias the exponent from 127 to 7, add a
    // rounding term (0x7FFFF plus the lowest kept mantissa bit, i.e. ties go
    // to even), then keep the bits from position 20 upward.
    auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
    auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
    f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
    auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);

    // NOTE(review): assumes select(mask, a, b) yields a where mask is true.
    result = select(f_bits < (121 << 23), result_low, result_high);

    // Saturate to 0x7E, the largest finite E4M3 magnitude (448). The
    // comparison is on raw bits, so Inf and NaN inputs take this path too.
    auto result_sat = Simd<uint8_t, N>(0x7E);
    result = select(f_bits >= fp8_max, result_sat, result);
    // Move the float32 sign bit (bit 31) into bit 7 of the byte.
    return result | Simd<uint8_t, N>(sign >> 24);
  }

  // Scalar path: wraps the value in a width-1 Simd and reuses the vector
  // implementation.
  template <typename T>
  uint8_t operator()(T x) {
    return (*this)(Simd<T, 1>(x)).value;
  }
};
// Decodes E4M3 float8 bit patterns (one uint8_t per element) to float32 by
// assembling the float32 bit representation directly.
struct FromFP8 {
  // Vector path: decodes N lanes at once.
  template <int N>
  Simd<float, N> operator()(Simd<uint8_t, N> x) {
    // Put the byte in the top 8 bits so its sign bit lands on bit 31.
    auto w = Simd<uint32_t, N>(x) << 24;
    auto sign = w & 0x80000000;
    auto nonsign = w & 0x7FFFFFFF;

    // Leading zeros beyond the first four mean the exponent field is zero;
    // the excess is the left shift that normalizes the mantissa.
    auto renorm_shift = clz(nonsign);
    renorm_shift = simd::select(
        renorm_shift > Simd<uint32_t, N>{4},
        renorm_shift - Simd<uint32_t, N>{4},
        Simd<uint32_t, N>{0});

    // 0x7F800000 when all seven non-sign bits are set (the addition carries
    // into bit 31), otherwise 0. OR-ing it in forces the float32 exponent to
    // all ones for that input.
    Simd<int32_t, N> inf_nan_mask =
        (Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;

    // All ones when nonsign == 0 (nonsign - 1 wraps and the sign bit is
    // smeared across the lane), otherwise 0.
    // NOTE(review): relies on >> being an arithmetic shift for
    // Simd<int32_t, N> -- confirm.
    auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;

    // Align the exponent/mantissa with the float32 layout (>> 4), rebias the
    // exponent by 0x78 (127 - 7) minus the normalization shift, apply the
    // two masks, and restore the sign.
    auto result = sign |
        ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
          inf_nan_mask) &
         ~zero_mask);
    return fp32_from_bits(result);
  }

  // Scalar path: wraps the byte in a width-1 Simd and reuses the vector
  // implementation.
  float operator()(uint8_t x) {
    return (*this)(Simd<uint8_t, 1>(x)).value;
  }
};
} // namespace mlx::core::detail

View File

@@ -52,6 +52,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
@@ -170,11 +171,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Suppress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -2,6 +2,8 @@
#pragma once
#include <cuda_fp8.h>
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -334,4 +336,17 @@ struct Tanh {
}
};
// Encodes a value as E4M3 float8 and returns the raw storage byte.
struct ToFP8 {
  template <typename T>
  __device__ uint8_t operator()(T x) {
    // Let the CUDA fp8 type perform the conversion, then expose its storage.
    const __nv_fp8_e4m3 encoded(x);
    return encoded.__x;
  }
};
// Decodes a raw E4M3 float8 storage byte to float.
struct FromFP8 {
  __device__ float operator()(uint8_t x) {
    // Rebuild the CUDA fp8 value from its storage byte, then convert it.
    __nv_fp8_e4m3 encoded;
    encoded.__x = x;
    return float(encoded);
  }
};
} // namespace mlx::core::cu

View File

@@ -0,0 +1,19 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
#include "mlx/fast_primitives.h"
namespace mlx::core {
// CUDA evaluation of the FP8 conversion primitive: runs the elementwise
// ToFP8 or FromFP8 device op over the single input, depending on the
// direction this primitive was constructed with.
void fast::ConvertFP8::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("ConvertFP8::eval_gpu");
  auto& out = outputs[0];
  auto& s = out.primitive().stream();
  if (to_fp8_) {
    unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
  } else {
    unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
  }
}
} // namespace mlx::core

View File

@@ -108,6 +108,12 @@ constexpr bool supports_unary_op() {
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, ToFP8>) {
return std::is_same_v<Out, uint8_t> && is_floating_v<In>;
}
if (std::is_same_v<Op, FromFP8>) {
return std::is_same_v<In, uint8_t> && is_floating_v<Out>;
}
return false;
}

View File

@@ -9,11 +9,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
index *= N;
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
out[index + i] = Op()(in[index + i]);
out[index + i] = static_cast<U>(Op()(in[index + i]));
}
} else {
for (int i = 0; i < N; ++i) {
out[index + i] = Op()(in[index + i]);
out[index + i] = static_cast<U>(Op()(in[index + i]));
}
}
}
@@ -28,11 +28,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
out[offset + i] = static_cast<U>(Op()(in[offset + i]));
}
} else {
for (int i = 0; i < N; ++i) {
out[offset + i] = Op()(in[offset + i]);
out[offset + i] = static_cast<U>(Op()(in[offset + i]));
}
}
}
@@ -57,7 +57,7 @@ template <
IdxT xstride = in_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
out[out_idx++] = Op()(in[idx]);
out[out_idx++] = static_cast<U>(Op()(in[idx]));
idx += xstride;
}
}

View File

@@ -103,4 +103,13 @@ instantiate_unary_base_same(Round, complex64, complex64_t)
instantiate_unary_base(Real, complex64, float32, complex64_t, float)
instantiate_unary_base(Imag, complex64, float32, complex64_t, float)
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on
instantiate_unary_all_same(LogicalNot, bool_, bool)
instantiate_unary_all(ToFP8, float16, uint8, float16_t, uint8_t)
instantiate_unary_all(ToFP8, bfloat16, uint8, bfloat16_t, uint8_t)
instantiate_unary_all(ToFP8, float32, uint8, float, uint8_t)
instantiate_unary_all(FromFP8, uint8, float16, uint8_t, float16_t)
instantiate_unary_all(FromFP8, uint8, bfloat16, uint8_t, bfloat16_t)
instantiate_unary_all(FromFP8, uint8, float32, uint8_t, float)
// clang-format on

View File

@@ -225,8 +225,7 @@ struct Floor {
};
struct Imag {
template <typename T>
T operator()(T x) {
float operator()(complex64_t x) {
return x.imag;
};
};
@@ -290,8 +289,7 @@ struct Negative {
};
struct Real {
template <typename T>
T operator()(T x) {
float operator()(complex64_t x) {
return x.real;
};
};
@@ -440,3 +438,64 @@ complex64_t ArcTan::operator()(complex64_t x) {
auto ix = i * x;
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
};
// Reinterprets a 32-bit pattern as a float; no bits are changed.
inline float fp32_from_bits(uint32_t bits) {
  return as_type<float>(bits);
}

// Returns the raw 32-bit pattern of a float; no bits are changed.
// The return type must be an integer: returning `float` would value-convert
// the bit pattern to floating point (and back again at the call site), which
// can round away its low bits.
inline uint32_t fp32_to_bits(float x) {
  return as_type<uint32_t>(x);
}
// Encodes a floating point value as an E4M3 float8 byte.
struct ToFP8 {
  template <typename T>
  uint8_t operator()(T f) {
    // From PyTorch
    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
    const uint32_t fp8_max = 1087 << 20;
    const uint32_t denorm_mask = 141 << 23;

    // Work on the magnitude; the sign bit is re-attached at the end.
    uint32_t bits = fp32_to_bits(static_cast<float>(f));
    const uint32_t sign = bits & 0x80000000;
    bits ^= sign;

    uint8_t encoded = 0u;
    if (bits >= fp8_max) {
      // Default behavior saturates to min/max
      encoded = 0x7E;
    } else if (bits < (121 << 23)) {
      // Low range: a float addition leaves the rounded encoding in the low
      // bits of the sum.
      bits = fp32_to_bits(fp32_from_bits(bits) + fp32_from_bits(denorm_mask));
      encoded = static_cast<uint8_t>(bits - denorm_mask);
    } else {
      // Normal range: rebias the exponent and round, adding one more when
      // the resulting mantissa is odd.
      const uint8_t mant_odd = (bits >> 20) & 1;
      bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
      bits += mant_odd;
      encoded = static_cast<uint8_t>(bits >> 20);
    }

    encoded |= static_cast<uint8_t>(sign >> 24);
    return encoded;
  }
};
// Decodes an E4M3 float8 byte to float.
struct FromFP8 {
  float operator()(uint8_t x) {
    // From PyTorch:
    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
    const uint32_t w = static_cast<uint32_t>(x) << 24;
    const uint32_t sign = w & 0x80000000;
    const uint32_t nonsign = w & 0x7FFFFFFF;

    // Only leading zeros beyond the first four call for renormalization.
    uint32_t renorm_shift = metal::clz(nonsign);
    if (renorm_shift > 4) {
      renorm_shift = renorm_shift - 4;
    } else {
      renorm_shift = 0;
    }

    const int32_t inf_nan_mask =
        (static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
    const int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;

    // Build the magnitude step by step, then re-attach the sign.
    const uint32_t aligned = (nonsign << renorm_shift) >> 4;
    const uint32_t exponent_bias = (0x78 - renorm_shift) << 23;
    const uint32_t magnitude =
        ((aligned + exponent_bias) | inf_nan_mask) & ~zero_mask;
    const uint32_t result = sign | magnitude;
    return fp32_from_bits(result);
  }
};

View File

@@ -6,6 +6,7 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
@@ -1108,4 +1109,12 @@ void fast::Quantize::eval_gpu(
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
// Metal evaluation of the FP8 conversion primitive: hands the single input
// and output to the generic elementwise unary dispatcher together with this
// primitive's name().
void fast::ConvertFP8::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  unary_op_gpu(inputs, outputs[0], name(), stream());
}
} // namespace mlx::core

View File

@@ -144,17 +144,7 @@ UNARY_GPU(Tan)
UNARY_GPU(Tanh)
// Evaluates the logarithm on the GPU. Every base issues exactly the same
// call, so no switch on base_ is needed and the work is dispatched once.
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op_gpu(inputs, out, name());
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -130,6 +130,7 @@ NO_CPU(View)
namespace fast {
NO_CPU_MULTI(Quantize)
NO_CPU_MULTI(ConvertFP8)
} // namespace fast
namespace distributed {

View File

@@ -154,6 +154,7 @@ NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(ConvertFP8)
NO_GPU_MULTI(Quantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast

View File

@@ -843,4 +843,9 @@ std::vector<Shape> Quantize::output_shapes(const std::vector<array>& inputs) {
}
}
bool ConvertFP8::is_equivalent(const Primitive& other) const {
const ConvertFP8& a_other = static_cast<const ConvertFP8&>(other);
return to_fp8_ == a_other.to_fp8_;
}
} // namespace mlx::core::fast

View File

@@ -250,6 +250,35 @@ class ScaledDotProductAttention : public Custom {
bool has_sinks_;
};
// Primitive that converts between floating point arrays and E4M3 float8
// stored as bytes. The direction is fixed at construction: `to_fp8 == true`
// encodes, `false` decodes.
class ConvertFP8 : public Primitive {
 public:
  explicit ConvertFP8(Stream stream, bool to_fp8)
      : Primitive(stream), to_fp8_(to_fp8) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  // The name reflects the conversion direction.
  const char* name() const override {
    return to_fp8_ ? "ToFP8" : "FromFP8";
  }

  // Exposes the conversion direction.
  bool state() const {
    return to_fp8_;
  }

  bool is_equivalent(const Primitive& other) const override;
  DEFINE_INPUT_OUTPUT_SHAPE()

 private:
  bool to_fp8_;
};
class Quantize : public Custom {
public:
explicit Quantize(

View File

@@ -4,7 +4,6 @@
#include <memory>
#include <stack>
#include "mlx/fast.h"
#include "mlx/io.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
@@ -103,92 +102,6 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
}
}
// Decodes an array of E4M3 float8 bytes `x` into an array of type `dtype`.
// When the target stream is on the GPU, a custom Metal kernel does the
// conversion; otherwise the same bit manipulation is expressed with array ops.
array f8_e4m3_to_float(array x, Dtype dtype, StreamOrDevice s) {
if (to_stream(s).device == Device::gpu) {
// From PyTorch:
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
// Kernel body: each thread reads x[elem], rebuilds the float32 bit pattern,
// and writes it to y[elem] cast to the template type T.
std::string source = R"(
uint elem = thread_position_in_grid.x;
uint8_t val = x[elem];
const uint32_t w = (uint32_t)val << 24;
const uint32_t sign = w & 0x80000000;
const uint32_t nonsign = w & 0x7FFFFFFF;
uint32_t renorm_shift = metal::clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
const int32_t inf_nan_mask =
((int32_t)(nonsign + 0x01000000) >> 8) & 0x7F800000;
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
uint32_t result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
float out = *(reinterpret_cast<thread float*>(&result));
y[elem] = static_cast<T>(out);
)";
auto kernel = fast::metal_kernel("f8_e4m3", {"x"}, {"y"}, source);
// The output has x's shape and the requested dtype, and T is bound to dtype.
// NOTE(review): {x.size(), 1, 1} and {256, 1, 1} are presumably the grid and
// threadgroup sizes (one thread per element) -- confirm against
// fast::metal_kernel.
auto outputs = kernel(
{x},
{x.shape()},
{dtype},
{x.size(), 1, 1},
{256, 1, 1},
{{"T", dtype}},
std::nullopt,
false,
s);
return outputs[0];
} else {
// Same steps as the kernel above, built from array ops.
auto w = left_shift(astype(x, uint32, s), array({24}, uint32), s);
auto sign = bitwise_and(w, array({0x80000000}, uint32), s);
auto nonsign = bitwise_and(w, array({0x7FFFFFFF}, uint32), s);
// Emulate a clz op with a lookup table
// The table is indexed by the low nibble of x and already holds the
// renormalization shift for that nibble.
auto clz_table =
array({28, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, uint32);
auto renorm_shift = take(clz_table, bitwise_and(x, array({0xf}), s), s);
// If any of the bits selected by 0x70 is set, no renormalization is needed.
renorm_shift = where(
greater(
bitwise_and(x, array({0x70}, uint32), s), array({0}, uint32), s),
array({0}, uint32),
renorm_shift,
s);
// 0x7F800000 when adding 0x01000000 carries into the sign bit, otherwise 0.
auto inf_nan_mask = bitwise_and(
right_shift(
astype(add(nonsign, array(0x01000000, int32), s), int32, s),
array({8}, int32),
s),
array({0x7F800000}, int32),
s);
// Intended to be all ones when nonsign == 0 and zero otherwise, as in the
// kernel above.
auto zero_mask = right_shift(
astype(subtract(nonsign, array({1}, uint32), s), int32, s),
array({31}, int32),
s);
zero_mask = astype(zero_mask, uint32, s);
inf_nan_mask = astype(inf_nan_mask, uint32, s);
// (nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)
auto result =
add(right_shift(
left_shift(nonsign, renorm_shift, s), array({4}, uint32), s),
left_shift(
subtract(array({0x78}, uint32), renorm_shift, s),
array({23}, uint32),
s),
s);
// sign | ((result | inf_nan_mask) & ~zero_mask)
result = bitwise_or(
sign,
bitwise_and(
bitwise_or(result, inf_nan_mask, s),
bitwise_invert(zero_mask, s),
s),
s);
// Reinterpret the assembled bits as float32, then cast to the requested type.
result = astype(view(result, float32, s), dtype, s);
return result;
}
}
/** Load array from reader in safetensor format */
SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream,
@@ -244,7 +157,7 @@ SafetensorsLoad load_safetensors(
stream, in_stream, offset + data_offsets.at(0), false),
std::vector<array>{});
if (dtype == ST_F8_E4M3) {
loaded_array = f8_e4m3_to_float(loaded_array, bfloat16, s);
loaded_array = from_fp8(loaded_array, bfloat16, s);
}
res.insert({item.key(), loaded_array});
}

View File

@@ -4506,6 +4506,40 @@ array dequantize(
}
}
// Builds a lazy array that decodes E4M3 float8 bytes in `x` to `dtype`.
// Throws std::invalid_argument when `x` is not uint8 or `dtype` is not a
// floating point type.
array from_fp8(array x, Dtype dtype, StreamOrDevice s) {
  // The encoded values must be stored as raw bytes.
  const auto input_type = x.dtype();
  if (input_type != uint8) {
    std::ostringstream err;
    err << "[from_fp8] Input must have type uint8 but "
        << "x.dtype() == " << input_type << ".";
    throw std::invalid_argument(err.str());
  }

  // The decode target has to be a floating point type.
  const bool target_is_floating = issubdtype(dtype, floating);
  if (!target_is_floating) {
    std::ostringstream err;
    err << "[from_fp8] Only real floating types are supported but "
        << "dtype == " << dtype << ".";
    throw std::invalid_argument(err.str());
  }

  // `false` selects the fp8 -> float direction.
  constexpr bool to_fp8 = false;
  return array(
      x.shape(),
      dtype,
      std::make_shared<fast::ConvertFP8>(to_stream(s), to_fp8),
      {x});
}
// Builds a lazy uint8 array holding the E4M3 float8 encoding of `x`.
// Throws std::invalid_argument when `x` is not a floating point array.
array to_fp8(array x, StreamOrDevice s) {
  const auto input_type = x.dtype();
  const bool input_is_floating = issubdtype(input_type, floating);
  if (!input_is_floating) {
    std::ostringstream err;
    err << "[to_fp8] Only real floating types are supported but "
        << "x.dtype() == " << input_type << ".";
    throw std::invalid_argument(err.str());
  }

  // `true` selects the float -> fp8 direction.
  constexpr bool to_fp8 = true;
  return array(
      x.shape(),
      uint8,
      std::make_shared<fast::ConvertFP8>(to_stream(s), to_fp8),
      {x});
}
array gather_qmm(
const array& x,
const array& w,

View File

@@ -1402,6 +1402,12 @@ array dequantize(
const std::string& mode = "affine",
StreamOrDevice s = {});
/**
 * Convert E4M3 float8 values, stored one per uint8 element of `x`, to the
 * given floating point dtype. The result has the same shape as `x`.
 *
 * Throws std::invalid_argument if `x` is not uint8 or if `dtype` is not a
 * floating point type.
 */
array from_fp8(array x, Dtype dtype, StreamOrDevice s = {});

/**
 * Convert a floating point array to E4M3 float8. The result is a uint8 array
 * with the same shape as `x`, one encoded value per element.
 *
 * Throws std::invalid_argument if `x` is not a floating point array.
 */
array to_fp8(array x, StreamOrDevice s = {});
/** Compute matrix products with matrix-level gather. */
array gather_qmm(
const array& x,

View File

@@ -2,7 +2,6 @@
// Required for using M_PI_2 in MSVC.
#define _USE_MATH_DEFINES
#include <cmath>
#include <numeric>
@@ -4030,3 +4029,26 @@ TEST_CASE("test conv_transpose3d with output_padding") {
{1, 2, 4, 4, 1});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test fp8 conversion") {
  // Values that E4M3 can represent exactly must survive an encode/decode
  // round trip for every supported floating point type.
  for (auto dtype : {float32, float16, bfloat16}) {
    array exact({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, dtype);
    auto decoded = from_fp8(to_fp8(exact), dtype);
    CHECK(array_equal(decoded, exact).item<bool>());
  }

  // Nearby values decode to the closest representable value.
  {
    array nearest({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});
    array noisy({-1.135, -1.01, 0.0001, 1.01, 1.135, 4.6, 447.0});
    auto decoded = from_fp8(to_fp8(noisy), float32);
    CHECK(array_equal(decoded, nearest).item<bool>());
  }

  // Overflow
  {
    auto too_large = array({-600.0, 600.0});
    auto decoded = from_fp8(to_fp8(too_large), float32);
    auto saturated = array({-448.0f, 448.0f});
    CHECK(array_equal(decoded, saturated, true).item<bool>());
  }
}