diff --git a/CMakeLists.txt b/CMakeLists.txt
index 84d4198ba..0dbc0b51b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -88,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
 
 add_library(mlx)
 
+# Suppress warnings: note: parameter passing for argument of type
+# ‘std::pair<...>’ when C++17 is enabled changed to match C++14 in GCC
+# 10.1
+target_compile_options(mlx PRIVATE -Wno-psabi)
+
 if(MLX_BUILD_CUDA)
   enable_language(CUDA)
 endif()
diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp
index de50cdb81..75a8e6233 100644
--- a/mlx/backend/cpu/quantized.cpp
+++ b/mlx/backend/cpu/quantized.cpp
@@ -1,8 +1,11 @@
 // Copyright © 2023 Apple Inc.
 
+#include "mlx/backend/common/unary.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
+#include "mlx/backend/cpu/unary.h"
+#include "mlx/backend/cpu/unary_ops.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
@@ -1102,4 +1105,44 @@ void fast::Quantize::eval_cpu(
   });
 }
 
+void fast::ConvertFP8::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  auto& in = inputs[0];
+  auto& out = outputs[0];
+  set_unary_output_data(in, out);
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.dispatch([in = array::unsafe_weak_copy(in),
+                    out = array::unsafe_weak_copy(out),
+                    to_fp8 = to_fp8_]() mutable {
+    if (to_fp8) {
+      switch (in.dtype()) {
+        case float16:
+          unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
+          break;
+        case bfloat16:
+          unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
+          break;
+        default:
+          unary_op<float, uint8_t>(in, out, detail::ToFP8());
+          break;
+      }
+    } else {
+      switch (out.dtype()) {
+        case float16:
+          unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
+          break;
+        case bfloat16:
+          unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
+          break;
+        default:
+          unary_op<uint8_t, float>(in, out, detail::FromFP8());
+          break;
+      }
+    }
+  });
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/cpu/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h
index 914831055..c89a104a0 100644
--- a/mlx/backend/cpu/simd/accelerate_simd.h
+++ b/mlx/backend/cpu/simd/accelerate_simd.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <arm_neon.h>
 #include <simd/math.h>
 #include <simd/vector.h>
 
@@ -200,6 +201,15 @@ SIMD_DEFAULT_COMPARISONS(<=)
 SIMD_DEFAULT_COMPARISONS(==)
 SIMD_DEFAULT_COMPARISONS(!=)
 
+template <typename T>
+Simd<T, 8> clz(Simd<T, 8> x) {
+  auto a = *(uint32x4_t*)(&x);
+  auto b = *((uint32x4_t*)(&x) + 1);
+  a = vclzq_u32(a);
+  b = vclzq_u32(b);
+  return asd::make_uint8(a, b);
+}
+
 template <typename T, int N>
 Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
   return asd::atan2(a.value, b.value);
diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h
index 17cd35b9a..fc9fbbff5 100644
--- a/mlx/backend/cpu/simd/base_simd.h
+++ b/mlx/backend/cpu/simd/base_simd.h
@@ -171,6 +171,11 @@ DEFAULT_BINARY(&)
 DEFAULT_BINARY(&&)
 DEFAULT_BINARY(||)
 
+template <typename T>
+Simd<T, 1> clz(Simd<T, 1> x_) {
+  return __builtin_clz(x_.value);
+}
+
 template <typename T>
 Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
   T a = a_.value;
diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h
index 14c1dd479..4fab6a754 100644
--- a/mlx/backend/cpu/unary.h
+++ b/mlx/backend/cpu/unary.h
@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
   auto ndim = a.ndim();
   if (a.flags().contiguous) {
     auto size = a.data_size();
-    constexpr int N = simd::max_size<T>;
+    constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
     while (size >= N) {
-      simd::store(dst, Op{}(simd::load<T, N>(src)));
+      simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
       size -= N;
       src += N;
       dst += N;
diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h
index 255e28a19..20d9c60f6 100644
--- a/mlx/backend/cpu/unary_ops.h
+++ b/mlx/backend/cpu/unary_ops.h
@@ -108,4 +108,73 @@ struct Square {
   SINGLE()
 };
 
+template <int N>
+Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
+  return *(Simd<float, N>*)(&x);
+}
+template <int N>
+Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
+  return *(Simd<uint32_t, N>*)(&x);
+}
+
+struct ToFP8 {
+  template <typename T, int N>
+  Simd<uint8_t, N> operator()(Simd<T, N> f) {
+    uint32_t fp8_max = 1087 << 20;
+    auto denorm_mask = Simd<uint32_t, N>(141 << 23);
+    Simd<uint32_t, N> f_bits;
+    Simd<float, N> f32 = f;
+    f_bits = fp32_to_bits(f32);
+    Simd<uint8_t, N> result = 0u;
+    auto sign = f_bits & 0x80000000;
+    f_bits = f_bits ^ sign;
+
+    auto f_bits_low =
+        fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
+    auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
+
+    auto mant_odd = Simd<uint32_t, N>((f_bits >> 20) & 1);
+    auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
+    f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
+
+    auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
+    result = select(f_bits < (121 << 23), result_low, result_high);
+
+    auto result_sat = Simd<uint8_t, N>(0x7E);
+    result = select(f_bits >= fp8_max, result_sat, result);
+    return result | Simd<uint8_t, N>(sign >> 24);
+  }
+
+  template <typename T>
+  uint8_t operator()(T x) {
+    return (*this)(Simd<T, 1>(x)).value;
+  }
+};
+
+struct FromFP8 {
+  template <int N>
+  Simd<float, N> operator()(Simd<uint8_t, N> x) {
+    auto w = Simd<uint32_t, N>(x) << 24;
+    auto sign = w & 0x80000000;
+    auto nonsign = w & 0x7FFFFFFF;
+
+    auto renorm_shift = clz(nonsign);
+    renorm_shift = simd::select(
+        renorm_shift > Simd<uint32_t, N>{4},
+        renorm_shift - Simd<uint32_t, N>{4},
+        Simd<uint32_t, N>{0});
+
+    Simd<int32_t, N> inf_nan_mask =
+        (Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
+    auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
+    auto result = sign |
+        ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
+          inf_nan_mask) &
+         ~zero_mask);
+    return fp32_from_bits(result);
+  }
+  float operator()(uint8_t x) {
+    return (*this)(Simd<uint8_t, 1>(x)).value;
+  }
+};
 } // namespace mlx::core::detail
diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 5c4bf7115..eabee94f2 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -52,6 +52,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
 
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
@@ -170,11 +171,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
 # Suppress nvcc warnings on MLX headers.
 target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe --diag_suppress=997>)
 
-# Supress warnings: note: parameter passing for argument of type
-# ‘std::pair<...>’ when C++17 is enabled changed to match C++14 in GCC
-# 10.1
-target_compile_options(mlx PRIVATE -Wno-psabi)
-
 # Install CCCL headers for JIT.
 install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
         DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
index fcd083f2f..61c587e27 100644
--- a/mlx/backend/cuda/device/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -2,6 +2,8 @@
 
 #pragma once
 
+#include <cuda_fp8.h>
+
 #include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 
@@ -334,4 +336,17 @@ struct Tanh {
   }
 };
 
+struct ToFP8 {
+  template <typename T>
+  __device__ uint8_t operator()(T x) {
+    return __nv_fp8_e4m3(x).__x;
+  }
+};
+
+struct FromFP8 {
+  __device__ float operator()(uint8_t x) {
+    return float(*(__nv_fp8_e4m3*)(&x));
+  }
+};
+
 } // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/quantized/convert_fp8.cu b/mlx/backend/cuda/quantized/convert_fp8.cu
new file mode 100644
index 000000000..c0be2e381
--- /dev/null
+++ b/mlx/backend/cuda/quantized/convert_fp8.cu
@@ -0,0 +1,19 @@
+// Copyright © 2025 Apple Inc.
+#include "mlx/backend/cuda/unary/unary.cuh"
+#include "mlx/fast_primitives.h"
+
+namespace mlx::core {
+void fast::ConvertFP8::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("ConvertFP8::eval_gpu");
+  auto& in = inputs[0];
+  auto& out = outputs[0];
+  auto& s = out.primitive().stream();
+  if (to_fp8_) {
+    unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
+  } else {
+    unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
+  }
+}
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/unary/unary.cuh b/mlx/backend/cuda/unary/unary.cuh
index a20e119ca..8f4a02d50 100644
--- a/mlx/backend/cuda/unary/unary.cuh
+++ b/mlx/backend/cuda/unary/unary.cuh
@@ -108,6 +108,12 @@ constexpr bool supports_unary_op() {
   if (std::is_same_v<Op, LogicalNot>) {
     return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
   }
+  if (std::is_same_v<Op, ToFP8>) {
+    return std::is_same_v<Out, uint8_t> && is_floating_v<In>;
+  }
+  if (std::is_same_v<Op, FromFP8>) {
+    return std::is_same_v<In, uint8_t> && is_floating_v<Out>;
+  }
   return false;
 }
diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h
index 649ba7f2c..db7be3d41 100644
--- a/mlx/backend/metal/kernels/unary.h
+++ b/mlx/backend/metal/kernels/unary.h
@@ -9,11 +9,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
   index *= N;
   if (N > 1 && index + N > size) {
     for (int i = 0; index + i < size; ++i) {
-      out[index + i] = Op()(in[index + i]);
+      out[index + i] = static_cast<U>(Op()(in[index + i]));
     }
   } else {
     for (int i = 0; i < N; ++i) {
-      out[index + i] = Op()(in[index + i]);
+      out[index + i] = static_cast<U>(Op()(in[index + i]));
     }
   }
 }
@@ -28,11 +28,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
   if (N > 1 && offset + N > size) {
     for (int i = 0; offset + i < size; ++i) {
-      out[offset + i] = Op()(in[offset + i]);
+      out[offset + i] = static_cast<U>(Op()(in[offset + i]));
     }
   } else {
     for (int i = 0; i < N; ++i) {
-      out[offset + i] = Op()(in[offset + i]);
+      out[offset + i] = static_cast<U>(Op()(in[offset + i]));
     }
   }
 }
@@ -57,7 +57,7 @@ template <
   IdxT xstride = in_strides[ndim - 1];
   IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
-    out[out_idx++] = Op()(in[idx]);
+    out[out_idx++] = static_cast<U>(Op()(in[idx]));
     idx += xstride;
   }
 }
diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal
index 160ef4af1..54a0f566c 100644
--- a/mlx/backend/metal/kernels/unary.metal
+++ b/mlx/backend/metal/kernels/unary.metal
@@ -103,4 +103,13 @@ instantiate_unary_base_same(Round, complex64, complex64_t)
 instantiate_unary_base(Real, complex64, float32, complex64_t, float)
 instantiate_unary_base(Imag, complex64, float32, complex64_t, float)
-instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on
+instantiate_unary_all_same(LogicalNot, bool_, bool)
+
+instantiate_unary_all(ToFP8, float16, uint8, float16_t, uint8_t)
+instantiate_unary_all(ToFP8, bfloat16, uint8, bfloat16_t, uint8_t)
+instantiate_unary_all(ToFP8, float32, uint8, float, uint8_t)
+instantiate_unary_all(FromFP8, uint8, float16, uint8_t, float16_t)
+instantiate_unary_all(FromFP8, uint8, bfloat16, uint8_t, bfloat16_t)
+instantiate_unary_all(FromFP8, uint8, float32, uint8_t, float)
+
+    // clang-format on
diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h
index 44d43cee8..423c07f66 100644
--- a/mlx/backend/metal/kernels/unary_ops.h
+++ b/mlx/backend/metal/kernels/unary_ops.h
@@ -225,8 +225,7 @@ struct Floor {
 };
 
 struct Imag {
-  template <typename T>
-  T operator()(T x) {
+  float operator()(complex64_t x) {
     return x.imag;
   };
 };
@@ -290,8 +289,7 @@ struct Negative {
 };
 
 struct Real {
-  template <typename T>
-  T operator()(T x) {
+  float operator()(complex64_t x) {
     return x.real;
   };
 };
@@ -440,3 +438,64 @@ complex64_t ArcTan::operator()(complex64_t x) {
   auto ix = i * x;
   return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
 };
+
+inline float fp32_from_bits(uint32_t bits) {
+  return *(reinterpret_cast<thread float*>(&bits));
+}
+inline uint32_t fp32_to_bits(float x) {
+  return *(reinterpret_cast<thread uint32_t*>(&x));
+}
+
+struct ToFP8 {
+  template <typename T>
+  uint8_t operator()(T f) {
+    // From PyTorch
+    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
+    uint32_t fp8_max = 1087 << 20;
+    uint32_t denorm_mask = 141 << 23;
+    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
+    uint8_t result = 0u;
+    uint32_t sign = f_bits & 0x80000000;
+    f_bits ^= sign;
+    if (f_bits >= fp8_max) {
+      // Default behavior saturates to min/max
+      result = 0x7E;
+    } else {
+      if (f_bits < (121 << 23)) {
+        f_bits =
+            fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
+        result = static_cast<uint8_t>(f_bits - denorm_mask);
+      } else {
+        // resulting mantissa is odd
+        uint8_t mant_odd = (f_bits >> 20) & 1;
+        f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
+        f_bits += mant_odd;
+        result = static_cast<uint8_t>(f_bits >> 20);
+      }
+    }
+    result |= static_cast<uint8_t>(sign >> 24);
+    return result;
+  }
+};
+
+struct FromFP8 {
+  float operator()(uint8_t x) {
+    // From PyTorch:
+    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
+    uint32_t w = static_cast<uint32_t>(x) << 24;
+    uint32_t sign = w & 0x80000000;
+    uint32_t nonsign = w & 0x7FFFFFFF;
+
+    uint32_t renorm_shift = metal::clz(nonsign);
+    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
+
+    int32_t inf_nan_mask =
+        (static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
+    int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
+    uint32_t result = sign |
+        ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
+          inf_nan_mask) &
+         ~zero_mask);
+    return fp32_from_bits(result);
+  }
+};
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 452bc4faa..328669a92 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -6,6 +6,7 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/reduce.h"
+#include "mlx/backend/metal/unary.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
@@ -1108,4 +1109,12 @@ void fast::Quantize::eval_gpu(
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
+void fast::ConvertFP8::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  auto& in = inputs[0];
+  auto& out = outputs[0];
+  unary_op_gpu(inputs, out, name(), stream());
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 8b983a45a..833b23f63 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -144,17 +144,7 @@ UNARY_GPU(Tan)
 UNARY_GPU(Tanh)
 
 void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
-  switch (base_) {
-    case Base::e:
-      unary_op_gpu(inputs, out, name());
-      break;
-    case Base::two:
-      unary_op_gpu(inputs, out, name());
-      break;
-    case Base::ten:
-      unary_op_gpu(inputs, out, name());
-      break;
-  }
+  unary_op_gpu(inputs, out, name());
 }
 
 void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index dba82c6dc..4d373bd1a 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -130,6 +130,7 @@ NO_CPU(View)
 
 namespace fast {
 NO_CPU_MULTI(Quantize)
+NO_CPU_MULTI(ConvertFP8)
 } // namespace fast
 
 namespace distributed {
diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp
index 22a0c8acc..a57df046c 100644
--- a/mlx/backend/no_gpu/primitives.cpp
+++ b/mlx/backend/no_gpu/primitives.cpp
@@ -154,6 +154,7 @@ NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
 NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
+NO_GPU_MULTI(ConvertFP8)
 NO_GPU_MULTI(Quantize)
 NO_GPU_MULTI(CustomKernel)
 } // namespace fast
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 0f34aec93..82c419264 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -843,4 +843,9 @@ std::vector<Shape> Quantize::output_shapes(const std::vector<array>& inputs) {
   }
 }
 
+bool ConvertFP8::is_equivalent(const Primitive& other) const {
+  const ConvertFP8& a_other = static_cast<const ConvertFP8&>(other);
+  return to_fp8_ == a_other.to_fp8_;
+}
+
 } // namespace mlx::core::fast
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index a8000485a..d2b4b5611 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -250,6 +250,35 @@ class ScaledDotProductAttention : public Custom {
   bool has_sinks_;
 };
 
+class ConvertFP8 : public Primitive {
+ public:
+  explicit ConvertFP8(Stream stream, bool to_fp8)
+      : Primitive(stream), to_fp8_(to_fp8) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  const char* name() const override {
+    if (to_fp8_) {
+      return "ToFP8";
+    } else {
+      return "FromFP8";
+    }
+  }
+  bool state() const {
+    return to_fp8_;
+  };
+
+  bool is_equivalent(const Primitive& other) const override;
+  DEFINE_INPUT_OUTPUT_SHAPE()
+
+ private:
+  bool to_fp8_;
+};
+
 class Quantize : public Custom {
  public:
   explicit Quantize(
diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp
index 7ee974186..d9b9e9e40 100644
--- a/mlx/io/safetensors.cpp
+++ b/mlx/io/safetensors.cpp
@@ -4,7 +4,6 @@
 #include <json.hpp>
 #include <stack>
 
-#include "mlx/fast.h"
 #include "mlx/io.h"
 #include "mlx/io/load.h"
 #include "mlx/ops.h"
@@ -103,92 +102,6 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
   }
 }
 
-array f8_e4m3_to_float(array x, Dtype dtype, StreamOrDevice s) {
-  if (to_stream(s).device == Device::gpu) {
-    // From PyTorch:
-    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
-    std::string source = R"(
-      uint elem = thread_position_in_grid.x;
-      uint8_t val = x[elem];
-
-      const uint32_t w = (uint32_t)val << 24;
-      const uint32_t sign = w & 0x80000000;
-      const uint32_t nonsign = w & 0x7FFFFFFF;
-
-      uint32_t renorm_shift = metal::clz(nonsign);
-      renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
-
-      const int32_t inf_nan_mask =
-          ((int32_t)(nonsign + 0x01000000) >> 8) & 0x7F800000;
-      const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
-      uint32_t result = sign |
-          ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
-          inf_nan_mask) &
-          ~zero_mask);
-
-      float out = *(reinterpret_cast<thread float*>(&result));
-      y[elem] = static_cast<T>(out);
-    )";
-    auto kernel = fast::metal_kernel("f8_e4m3", {"x"}, {"y"}, source);
-    auto outputs = kernel(
-        {x},
-        {x.shape()},
-        {dtype},
-        {x.size(), 1, 1},
-        {256, 1, 1},
-        {{"T", dtype}},
-        std::nullopt,
-        false,
-        s);
-    return outputs[0];
-  } else {
-    auto w = left_shift(astype(x, uint32, s), array({24}, uint32), s);
-    auto sign = bitwise_and(w, array({0x80000000}, uint32), s);
-    auto nonsign = bitwise_and(w, array({0x7FFFFFFF}, uint32), s);
-
-    // Emulate a clz op with a lookup table
-    auto clz_table =
-        array({28, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, uint32);
-    auto renorm_shift = take(clz_table, bitwise_and(x, array({0xf}), s), s);
-    renorm_shift = where(
-        greater(
-            bitwise_and(x, array({0x70}, uint32), s), array({0}, uint32), s),
-        array({0}, uint32),
-        renorm_shift,
-        s);
-    auto inf_nan_mask = bitwise_and(
-        right_shift(
-            astype(add(nonsign, array(0x01000000, int32), s), int32, s),
-            array({8}, int32),
-            s),
-        array({0x7F800000}, int32),
-        s);
-    auto zero_mask = right_shift(
-        astype(subtract(nonsign, array({1}, uint32), s), int32, s),
-        array({31}, int32),
-        s);
-    zero_mask = astype(zero_mask, uint32, s);
-    inf_nan_mask = astype(inf_nan_mask, uint32, s);
-    auto result =
-        add(right_shift(
-                left_shift(nonsign, renorm_shift, s), array({4}, uint32), s),
-            left_shift(
-                subtract(array({0x78}, uint32), renorm_shift, s),
-                array({23}, uint32),
-                s),
-            s);
-    result = bitwise_or(
-        sign,
-        bitwise_and(
-            bitwise_or(result, inf_nan_mask, s),
-            bitwise_invert(zero_mask, s),
-            s),
-        s);
-    result = astype(view(result, float32, s), dtype, s);
-    return result;
-  }
-}
-
 /** Load array from reader in safetensor format */
 SafetensorsLoad load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
@@ -244,7 +157,7 @@ SafetensorsLoad load_safetensors(
         stream, in_stream, offset + data_offsets.at(0), false),
         std::vector<array>{});
     if (dtype == ST_F8_E4M3) {
-      loaded_array = f8_e4m3_to_float(loaded_array, bfloat16, s);
+      loaded_array = from_fp8(loaded_array, bfloat16, s);
     }
     res.insert({item.key(), loaded_array});
   }
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 30e934f82..879ef4fd5 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4506,6 +4506,40 @@ array dequantize(
   }
 }
 
+array from_fp8(array x, Dtype dtype, StreamOrDevice s) {
+  if (x.dtype() != uint8) {
+    std::ostringstream msg;
+    msg << "[from_fp8] Input must have type uint8 but "
+        << "x.dtype() == " << x.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (!issubdtype(dtype, floating)) {
+    std::ostringstream msg;
+    msg << "[from_fp8] Only real floating types are supported but "
+        << "dtype == " << dtype << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  return array(
+      x.shape(),
+      dtype,
+      std::make_shared<fast::ConvertFP8>(to_stream(s), false),
+      {x});
+}
+
+array to_fp8(array x, StreamOrDevice s) {
+  if (!issubdtype(x.dtype(), floating)) {
+    std::ostringstream msg;
+    msg << "[to_fp8] Only real floating types are supported but "
+        << "x.dtype() == " << x.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  return array(
+      x.shape(),
+      uint8,
+      std::make_shared<fast::ConvertFP8>(to_stream(s), true),
+      {x});
+}
+
 array gather_qmm(
     const array& x,
     const array& w,
diff --git a/mlx/ops.h b/mlx/ops.h
index bfe6eff16..312caac6d 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1402,6 +1402,12 @@ array dequantize(
     const std::string& mode = "affine",
     StreamOrDevice s = {});
 
+/** Convert an E4M3 float8 to the given floating point dtype. */
+array from_fp8(array x, Dtype dtype, StreamOrDevice s = {});
+
+/** Convert a floating point array to E4M3 float8. */
+array to_fp8(array x, StreamOrDevice s = {});
+
 /** Compute matrix products with matrix-level gather. */
 array gather_qmm(
     const array& x,
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 2e8bbd692..c473b59c3 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -2,7 +2,6 @@
 
 // Required for using M_PI_2 in MSVC.
 #define _USE_MATH_DEFINES
-
 #include <cmath>
 
 #include <doctest/doctest.h>
@@ -4030,3 +4029,26 @@ TEST_CASE("test conv_transpose3d with output_padding") {
       {1, 2, 4, 4, 1});
   CHECK(array_equal(out, expected).item<bool>());
 }
+
+TEST_CASE("test fp8 conversion") {
+  for (auto t : {float32, float16, bfloat16}) {
+    array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, t);
+    auto in_fp8 = to_fp8(in);
+    auto out = from_fp8(in_fp8, t);
+    CHECK(array_equal(out, in).item<bool>());
+  }
+
+  array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});
+  array noisy_in({-1.135, -1.01, 0.0001, 1.01, 1.135, 4.6, 447.0});
+  auto in_fp8 = to_fp8(noisy_in);
+  auto out = from_fp8(in_fp8, float32);
+  CHECK(array_equal(out, in).item<bool>());
+
+  // Overflow
+  in = array({-600.0, 600.0});
+  in_fp8 = to_fp8(in);
+  out = from_fp8(in_fp8, float32);
+
+  auto expected = array({-448.0f, 448.0f});
+  CHECK(array_equal(out, expected, true).item<bool>());
+}
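Note on the E4M3 encoding used throughout this patch: the E4M3 "fn" variant implemented here (and in the PyTorch Float8_e4m3fn code cited in the comments) has 1 sign bit, 4 exponent bits with bias 7, and 3 mantissa bits. It encodes no infinities; 0x7F and 0xFF are NaN, so the largest finite magnitude is 0x7E = 0b0'1111'110 = 448. That is why ToFP8 saturates out-of-range inputs to 0x7E and why the test above expects ±448 for ±600.

A minimal round-trip sketch of the new public ops, mirroring the test case above (illustrative usage only, not part of the patch; it assumes the mlx/mlx.h umbrella header and the mlx::core namespace):

    // fp8_roundtrip.cpp -- hedged example built on the ops this patch adds.
    #include <cassert>
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // All of these values are exactly representable in E4M3.
      array x({-1.125f, 0.0f, 1.0f, 448.0f});
      array packed = to_fp8(x);            // dtype uint8, one byte per element
      array restored = from_fp8(packed, float32);
      assert(array_equal(restored, x).item<bool>());

      // Out-of-range magnitudes saturate to the max finite value, +/-448.
      array big({600.0f});
      assert(from_fp8(to_fp8(big), float32).item<float>() == 448.0f);
      return 0;
    }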