mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fp8 conversion (#2686)
* add fp8 e4m3 converters * add cuda * default saturate to min/max * fix for older OS * fix no gpu/cpu * fix saturate * fix compile
This commit is contained in:
@@ -88,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
add_library(mlx)
|
||||
|
||||
# Suppress warnings: note: parameter passing for argument of type
|
||||
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||
# 10.1
|
||||
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/backend/cpu/unary.h"
|
||||
#include "mlx/backend/cpu/unary_ops.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -1102,4 +1105,44 @@ void fast::Quantize::eval_cpu(
|
||||
});
|
||||
}
|
||||
|
||||
void fast::ConvertFP8::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
set_unary_output_data(in, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
to_fp8 = to_fp8_]() mutable {
|
||||
if (to_fp8) {
|
||||
switch (in.dtype()) {
|
||||
case float16:
|
||||
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
|
||||
break;
|
||||
default:
|
||||
unary_op<float, uint8_t>(in, out, detail::ToFP8());
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
|
||||
break;
|
||||
default:
|
||||
unary_op<uint8_t, float>(in, out, detail::FromFP8());
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
@@ -200,6 +201,15 @@ SIMD_DEFAULT_COMPARISONS(<=)
|
||||
SIMD_DEFAULT_COMPARISONS(==)
|
||||
SIMD_DEFAULT_COMPARISONS(!=)
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> clz(Simd<T, N> x) {
|
||||
auto a = *(uint32x4_t*)(&x);
|
||||
auto b = *((uint32x4_t*)(&x) + 1);
|
||||
a = vclzq_u32(a);
|
||||
b = vclzq_u32(b);
|
||||
return asd::make_uint8(a, b);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
|
||||
return asd::atan2(a.value, b.value);
|
||||
|
||||
@@ -171,6 +171,11 @@ DEFAULT_BINARY(&)
|
||||
DEFAULT_BINARY(&&)
|
||||
DEFAULT_BINARY(||)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> clz(Simd<T, 1> x_) {
|
||||
return __builtin_clz(x_.value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
|
||||
@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
|
||||
auto ndim = a.ndim();
|
||||
if (a.flags().contiguous) {
|
||||
auto size = a.data_size();
|
||||
constexpr int N = simd::max_size<T>;
|
||||
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
|
||||
while (size >= N) {
|
||||
simd::store(dst, Op{}(simd::load<T, N>(src)));
|
||||
simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
|
||||
size -= N;
|
||||
src += N;
|
||||
dst += N;
|
||||
|
||||
@@ -108,4 +108,73 @@ struct Square {
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
template <int N>
|
||||
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
|
||||
return *(Simd<float, N>*)(&x);
|
||||
}
|
||||
template <int N>
|
||||
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
|
||||
return *(Simd<uint32_t, N>*)(&x);
|
||||
}
|
||||
|
||||
struct ToFP8 {
|
||||
template <typename T, int N>
|
||||
Simd<uint8_t, N> operator()(Simd<T, N> f) {
|
||||
uint32_t fp8_max = 1087 << 20;
|
||||
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
|
||||
Simd<uint32_t, N> f_bits;
|
||||
Simd<float, N> f32 = f;
|
||||
f_bits = fp32_to_bits(f32);
|
||||
Simd<uint8_t, N> result = 0u;
|
||||
auto sign = f_bits & 0x80000000;
|
||||
f_bits = f_bits ^ sign;
|
||||
|
||||
auto f_bits_low =
|
||||
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
|
||||
|
||||
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
|
||||
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
|
||||
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
|
||||
|
||||
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
|
||||
result = select(f_bits < (121 << 23), result_low, result_high);
|
||||
|
||||
auto result_sat = Simd<uint8_t, N>(0x7E);
|
||||
result = select(f_bits >= fp8_max, result_sat, result);
|
||||
return result | Simd<uint8_t, N>(sign >> 24);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
uint8_t operator()(T x) {
|
||||
return (*this)(Simd<T, 1>(x)).value;
|
||||
}
|
||||
};
|
||||
|
||||
struct FromFP8 {
|
||||
template <int N>
|
||||
Simd<float, N> operator()(Simd<uint8_t, N> x) {
|
||||
auto w = Simd<uint32_t, N>(x) << 24;
|
||||
auto sign = w & 0x80000000;
|
||||
auto nonsign = w & 0x7FFFFFFF;
|
||||
|
||||
auto renorm_shift = clz(nonsign);
|
||||
renorm_shift = simd::select(
|
||||
renorm_shift > Simd<uint32_t, N>{4},
|
||||
renorm_shift - Simd<uint32_t, N>{4},
|
||||
Simd<uint32_t, N>{0});
|
||||
|
||||
Simd<int32_t, N> inf_nan_mask =
|
||||
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
|
||||
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
|
||||
auto result = sign |
|
||||
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
|
||||
inf_nan_mask) &
|
||||
~zero_mask);
|
||||
return fp32_from_bits(result);
|
||||
}
|
||||
float operator()(uint8_t x) {
|
||||
return (*this)(Simd<uint8_t, 1>(x)).value;
|
||||
}
|
||||
};
|
||||
} // namespace mlx::core::detail
|
||||
|
||||
@@ -52,6 +52,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
|
||||
@@ -170,11 +171,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
# Suppress warnings: note: parameter passing for argument of type
|
||||
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||
# 10.1
|
||||
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||
|
||||
# Install CCCL headers for JIT.
|
||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
@@ -334,4 +336,17 @@ struct Tanh {
|
||||
}
|
||||
};
|
||||
|
||||
struct ToFP8 {
|
||||
template <typename T>
|
||||
__device__ uint8_t operator()(T x) {
|
||||
return __nv_fp8_e4m3(x).__x;
|
||||
}
|
||||
};
|
||||
|
||||
struct FromFP8 {
|
||||
__device__ float operator()(uint8_t x) {
|
||||
return float(*(__nv_fp8_e4m3*)(&x));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
19
mlx/backend/cuda/quantized/convert_fp8.cu
Normal file
19
mlx/backend/cuda/quantized/convert_fp8.cu
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
void fast::ConvertFP8::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("ConvertFP8::eval_gpu");
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
auto& s = out.primitive().stream();
|
||||
if (to_fp8_) {
|
||||
unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
|
||||
} else {
|
||||
unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
@@ -108,6 +108,12 @@ constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, LogicalNot>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, ToFP8>) {
|
||||
return std::is_same_v<Out, uint8_t> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, FromFP8>) {
|
||||
return std::is_same_v<In, uint8_t> && is_floating_v<Out>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -9,11 +9,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
index *= N;
|
||||
if (N > 1 && index + N > size) {
|
||||
for (int i = 0; index + i < size; ++i) {
|
||||
out[index + i] = Op()(in[index + i]);
|
||||
out[index + i] = static_cast<U>(Op()(in[index + i]));
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
out[index + i] = Op()(in[index + i]);
|
||||
out[index + i] = static_cast<U>(Op()(in[index + i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -28,11 +28,11 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
|
||||
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||
if (N > 1 && offset + N > size) {
|
||||
for (int i = 0; offset + i < size; ++i) {
|
||||
out[offset + i] = Op()(in[offset + i]);
|
||||
out[offset + i] = static_cast<U>(Op()(in[offset + i]));
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
out[offset + i] = Op()(in[offset + i]);
|
||||
out[offset + i] = static_cast<U>(Op()(in[offset + i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,7 @@ template <
|
||||
IdxT xstride = in_strides[ndim - 1];
|
||||
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
out[out_idx++] = Op()(in[idx]);
|
||||
out[out_idx++] = static_cast<U>(Op()(in[idx]));
|
||||
idx += xstride;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,4 +103,13 @@ instantiate_unary_base_same(Round, complex64, complex64_t)
|
||||
instantiate_unary_base(Real, complex64, float32, complex64_t, float)
|
||||
instantiate_unary_base(Imag, complex64, float32, complex64_t, float)
|
||||
|
||||
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on
|
||||
instantiate_unary_all_same(LogicalNot, bool_, bool)
|
||||
|
||||
instantiate_unary_all(ToFP8, float16, uint8, float16_t, uint8_t)
|
||||
instantiate_unary_all(ToFP8, bfloat16, uint8, bfloat16_t, uint8_t)
|
||||
instantiate_unary_all(ToFP8, float32, uint8, float, uint8_t)
|
||||
instantiate_unary_all(FromFP8, uint8, float16, uint8_t, float16_t)
|
||||
instantiate_unary_all(FromFP8, uint8, bfloat16, uint8_t, bfloat16_t)
|
||||
instantiate_unary_all(FromFP8, uint8, float32, uint8_t, float)
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -225,8 +225,7 @@ struct Floor {
|
||||
};
|
||||
|
||||
// Unary functor returning the imaginary component of a complex value as a
// float. Only the complex64_t overload is provided.
struct Imag {
  float operator()(complex64_t x) {
    return x.imag;
  }
};
|
||||
@@ -290,8 +289,7 @@ struct Negative {
|
||||
};
|
||||
|
||||
// Unary functor returning the real component of a complex value as a float.
// Only the complex64_t overload is provided.
struct Real {
  float operator()(complex64_t x) {
    return x.real;
  }
};
|
||||
@@ -440,3 +438,64 @@ complex64_t ArcTan::operator()(complex64_t x) {
|
||||
auto ix = i * x;
|
||||
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
|
||||
};
|
||||
|
||||
// Reinterpret a 32-bit pattern as a float; the bits are reused as-is and no
// numeric conversion takes place.
inline float fp32_from_bits(uint32_t bits) {
  thread float* as_float = reinterpret_cast<thread float*>(&bits);
  return *as_float;
}
|
||||
// Reinterpret a float as its 32-bit pattern.
//
// The return type is uint32_t so the pattern reaches the caller unchanged.
// Returning it as `float` would numerically convert the integer to float and
// back, rounding away low-order bits that callers extract with integer
// arithmetic.
inline uint32_t fp32_to_bits(float x) {
  return *(reinterpret_cast<thread uint32_t*>(&x));
}
|
||||
|
||||
// Encodes a floating point value as an E4M3 FP8 byte (sign bit, 4 exponent
// bits, 3 mantissa bits), rounding to nearest even.
//
// Out-of-range handling:
//  - finite magnitudes that would round past the largest finite value (448,
//    code 0x7E) saturate to 0x7E, and so does infinity;
//  - NaN inputs produce the NaN code 0x7F instead of a finite value.
struct ToFP8 {
  template <typename T>
  uint8_t operator()(T f) {
    // From PyTorch
    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
    // Bits of 464.0f, the midpoint between 448 (0x7E) and the value the NaN
    // code 0x7F would otherwise stand for (480). Everything at or above it
    // must saturate; anything above it would otherwise round up to 0x7F.
    uint32_t sat_min = 0x43E80000;
    // Smallest float32 magnitude bit pattern that is a NaN.
    uint32_t nan_min = 0x7F800001;
    // Bits of 2^14. Adding this float to a tiny magnitude makes the float
    // addition itself round the value into the low bits of the sum.
    uint32_t denorm_mask = 141 << 23;

    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
    uint8_t result = 0u;

    // Split off the sign; everything below works on the magnitude bits.
    uint32_t sign = f_bits & 0x80000000;
    f_bits ^= sign;

    if (f_bits >= nan_min) {
      // NaN stays NaN
      result = 0x7F;
    } else if (f_bits >= sat_min) {
      // Default behavior saturates to min/max
      result = 0x7E;
    } else if (f_bits < (121 << 23)) {
      // Magnitude below 2^-6: E4M3 subnormal or zero
      f_bits =
          fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
      result = static_cast<uint8_t>(f_bits - denorm_mask);
    } else {
      // resulting mantissa is odd
      uint8_t mant_odd = (f_bits >> 20) & 1;
      // Rebias the exponent from 127 to 7 and round to nearest even
      f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
      f_bits += mant_odd;
      result = static_cast<uint8_t>(f_bits >> 20);
    }

    result |= static_cast<uint8_t>(sign >> 24);
    return result;
  }
};
|
||||
|
||||
// Decodes an E4M3 FP8 byte into a float using integer bit manipulation only.
struct FromFP8 {
  float operator()(uint8_t x) {
    // From PyTorch:
    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
    // Move the byte to the top of a 32-bit word so its sign bit lines up
    // with the float32 sign bit.
    const uint32_t word = static_cast<uint32_t>(x) << 24;
    const uint32_t sign_bit = word & 0x80000000;
    const uint32_t magnitude = word & 0x7FFFFFFF;

    // Number of positions the magnitude must be shifted left to normalize a
    // subnormal input; zero when the exponent field is non-zero.
    const uint32_t leading_zeros = metal::clz(magnitude);
    const uint32_t shift = leading_zeros > 4 ? leading_zeros - 4 : 0;

    // 0x7F800000 when the magnitude is 0x7F000000 or above, otherwise 0.
    const int32_t inf_nan_mask =
        (static_cast<int32_t>(magnitude + 0x01000000) >> 8) & 0x7F800000;
    // All ones when the magnitude is zero, otherwise 0.
    const int32_t zero_mask = static_cast<int32_t>(magnitude - 1) >> 31;

    // Place exponent and mantissa at their float32 positions and rebias the
    // exponent (0x78 == 127 - 7).
    const uint32_t mantissa = magnitude << shift >> 4;
    const uint32_t exponent = (0x78 - shift) << 23;
    const uint32_t body = ((mantissa + exponent) | inf_nan_mask) & ~zero_mask;

    return fp32_from_bits(sign_bit | body);
  }
};
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/unary.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -1108,4 +1109,12 @@ void fast::Quantize::eval_gpu(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fast::ConvertFP8::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
unary_op_gpu(inputs, out, name(), stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -144,17 +144,7 @@ UNARY_GPU(Tan)
|
||||
UNARY_GPU(Tanh)
|
||||
|
||||
// GPU evaluation of Log. Every logarithm base dispatches the same way: the
// unary kernel is looked up by name(), so no switch on the base is needed
// and the kernel is dispatched exactly once.
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op_gpu(inputs, out, name());
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
@@ -130,6 +130,7 @@ NO_CPU(View)
|
||||
|
||||
namespace fast {
|
||||
NO_CPU_MULTI(Quantize)
|
||||
NO_CPU_MULTI(ConvertFP8)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
|
||||
@@ -154,6 +154,7 @@ NO_GPU_USE_FALLBACK(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(ConvertFP8)
|
||||
NO_GPU_MULTI(Quantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
@@ -843,4 +843,9 @@ std::vector<Shape> Quantize::output_shapes(const std::vector<array>& inputs) {
|
||||
}
|
||||
}
|
||||
|
||||
bool ConvertFP8::is_equivalent(const Primitive& other) const {
|
||||
const ConvertFP8& a_other = static_cast<const ConvertFP8&>(other);
|
||||
return to_fp8_ == a_other.to_fp8_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
||||
@@ -250,6 +250,35 @@ class ScaledDotProductAttention : public Custom {
|
||||
bool has_sinks_;
|
||||
};
|
||||
|
||||
class ConvertFP8 : public Primitive {
|
||||
public:
|
||||
explicit ConvertFP8(Stream stream, bool to_fp8)
|
||||
: Primitive(stream), to_fp8_(to_fp8) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
const char* name() const override {
|
||||
if (to_fp8_) {
|
||||
return "ToFP8";
|
||||
} else {
|
||||
return "FromFP8";
|
||||
}
|
||||
}
|
||||
bool state() const {
|
||||
return to_fp8_;
|
||||
};
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
bool to_fp8_;
|
||||
};
|
||||
|
||||
class Quantize : public Custom {
|
||||
public:
|
||||
explicit Quantize(
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/io.h"
|
||||
#include "mlx/io/load.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -103,92 +102,6 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
|
||||
}
|
||||
}
|
||||
|
||||
array f8_e4m3_to_float(array x, Dtype dtype, StreamOrDevice s) {
|
||||
if (to_stream(s).device == Device::gpu) {
|
||||
// From PyTorch:
|
||||
// https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
|
||||
std::string source = R"(
|
||||
uint elem = thread_position_in_grid.x;
|
||||
uint8_t val = x[elem];
|
||||
|
||||
const uint32_t w = (uint32_t)val << 24;
|
||||
const uint32_t sign = w & 0x80000000;
|
||||
const uint32_t nonsign = w & 0x7FFFFFFF;
|
||||
|
||||
uint32_t renorm_shift = metal::clz(nonsign);
|
||||
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
|
||||
|
||||
const int32_t inf_nan_mask =
|
||||
((int32_t)(nonsign + 0x01000000) >> 8) & 0x7F800000;
|
||||
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
|
||||
uint32_t result = sign |
|
||||
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
|
||||
inf_nan_mask) &
|
||||
~zero_mask);
|
||||
|
||||
float out = *(reinterpret_cast<thread float*>(&result));
|
||||
y[elem] = static_cast<T>(out);
|
||||
)";
|
||||
auto kernel = fast::metal_kernel("f8_e4m3", {"x"}, {"y"}, source);
|
||||
auto outputs = kernel(
|
||||
{x},
|
||||
{x.shape()},
|
||||
{dtype},
|
||||
{x.size(), 1, 1},
|
||||
{256, 1, 1},
|
||||
{{"T", dtype}},
|
||||
std::nullopt,
|
||||
false,
|
||||
s);
|
||||
return outputs[0];
|
||||
} else {
|
||||
auto w = left_shift(astype(x, uint32, s), array({24}, uint32), s);
|
||||
auto sign = bitwise_and(w, array({0x80000000}, uint32), s);
|
||||
auto nonsign = bitwise_and(w, array({0x7FFFFFFF}, uint32), s);
|
||||
|
||||
// Emulate a clz op with a lookup table
|
||||
auto clz_table =
|
||||
array({28, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, uint32);
|
||||
auto renorm_shift = take(clz_table, bitwise_and(x, array({0xf}), s), s);
|
||||
renorm_shift = where(
|
||||
greater(
|
||||
bitwise_and(x, array({0x70}, uint32), s), array({0}, uint32), s),
|
||||
array({0}, uint32),
|
||||
renorm_shift,
|
||||
s);
|
||||
auto inf_nan_mask = bitwise_and(
|
||||
right_shift(
|
||||
astype(add(nonsign, array(0x01000000, int32), s), int32, s),
|
||||
array({8}, int32),
|
||||
s),
|
||||
array({0x7F800000}, int32),
|
||||
s);
|
||||
auto zero_mask = right_shift(
|
||||
astype(subtract(nonsign, array({1}, uint32), s), int32, s),
|
||||
array({31}, int32),
|
||||
s);
|
||||
zero_mask = astype(zero_mask, uint32, s);
|
||||
inf_nan_mask = astype(inf_nan_mask, uint32, s);
|
||||
auto result =
|
||||
add(right_shift(
|
||||
left_shift(nonsign, renorm_shift, s), array({4}, uint32), s),
|
||||
left_shift(
|
||||
subtract(array({0x78}, uint32), renorm_shift, s),
|
||||
array({23}, uint32),
|
||||
s),
|
||||
s);
|
||||
result = bitwise_or(
|
||||
sign,
|
||||
bitwise_and(
|
||||
bitwise_or(result, inf_nan_mask, s),
|
||||
bitwise_invert(zero_mask, s),
|
||||
s),
|
||||
s);
|
||||
result = astype(view(result, float32, s), dtype, s);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/** Load array from reader in safetensor format */
|
||||
SafetensorsLoad load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
@@ -244,7 +157,7 @@ SafetensorsLoad load_safetensors(
|
||||
stream, in_stream, offset + data_offsets.at(0), false),
|
||||
std::vector<array>{});
|
||||
if (dtype == ST_F8_E4M3) {
|
||||
loaded_array = f8_e4m3_to_float(loaded_array, bfloat16, s);
|
||||
loaded_array = from_fp8(loaded_array, bfloat16, s);
|
||||
}
|
||||
res.insert({item.key(), loaded_array});
|
||||
}
|
||||
|
||||
34
mlx/ops.cpp
34
mlx/ops.cpp
@@ -4506,6 +4506,40 @@ array dequantize(
|
||||
}
|
||||
}
|
||||
|
||||
// Decode an array of E4M3 FP8 bytes into the requested floating point dtype.
//
// x:     input array; must have dtype uint8 (one FP8 value per element).
// dtype: output dtype; must be a real floating point type.
// s:     stream or device to run on.
//
// Throws std::invalid_argument when `x` is not uint8 or `dtype` is not a
// real floating point type.
array from_fp8(array x, Dtype dtype, StreamOrDevice s) {
  if (x.dtype() != uint8) {
    std::ostringstream msg;
    msg << "[from_fp8] Input must have type uint8 but "
        << "x.dtype() == " << x.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }
  if (!issubdtype(dtype, floating)) {
    std::ostringstream msg;
    msg << "[from_fp8] Only real floating types are supported but "
        << "dtype == " << dtype << ".";
    throw std::invalid_argument(msg.str());
  }
  // The conversion kernels only produce float16, bfloat16 and float32
  // elements. For float64, decode to float32 and widen; every FP8 value is
  // exactly representable in float32, so nothing is lost.
  if (dtype == float64) {
    return astype(from_fp8(x, float32, s), float64, s);
  }
  return array(
      x.shape(),
      dtype,
      std::make_shared<fast::ConvertFP8>(to_stream(s), false),
      {x});
}
|
||||
|
||||
// Encode a floating point array as E4M3 FP8, returned as a uint8 array of
// the same shape (one FP8 value per element).
//
// x: input array; must have a real floating point dtype.
// s: stream or device to run on.
//
// Throws std::invalid_argument when `x` is not a real floating point array.
array to_fp8(array x, StreamOrDevice s) {
  if (!issubdtype(x.dtype(), floating)) {
    std::ostringstream msg;
    msg << "[to_fp8] Only real floating types are supported but "
        << "x.dtype() == " << x.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }
  // The conversion kernels only read float16, bfloat16 and float32 elements,
  // so float64 input is narrowed to float32 first. The value is therefore
  // rounded twice (float64 -> float32 -> FP8).
  if (x.dtype() == float64) {
    x = astype(x, float32, s);
  }
  return array(
      x.shape(),
      uint8,
      std::make_shared<fast::ConvertFP8>(to_stream(s), true),
      {x});
}
|
||||
|
||||
array gather_qmm(
|
||||
const array& x,
|
||||
const array& w,
|
||||
|
||||
@@ -1402,6 +1402,12 @@ array dequantize(
|
||||
const std::string& mode = "affine",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Convert an E4M3 float8 to the given floating point dtype. */
|
||||
array from_fp8(array x, Dtype dtype, StreamOrDevice s = {});
|
||||
|
||||
/** Convert a floating point array to E4M3 float8. */
|
||||
array to_fp8(array x, StreamOrDevice s = {});
|
||||
|
||||
/** Compute matrix products with matrix-level gather. */
|
||||
array gather_qmm(
|
||||
const array& x,
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
// Required for using M_PI_2 in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
@@ -4030,3 +4029,26 @@ TEST_CASE("test conv_transpose3d with output_padding") {
|
||||
{1, 2, 4, 4, 1});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test fp8 conversion") {
  // Values that FP8 represents exactly must survive an encode/decode round
  // trip unchanged for every dtype exercised here.
  for (auto dtype : {float32, float16, bfloat16}) {
    array exact({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, dtype);
    auto decoded = from_fp8(to_fp8(exact), dtype);
    CHECK(array_equal(decoded, exact).item<bool>());
  }

  // Slightly perturbed inputs must round to the nearby representable values.
  array nearest({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});
  array perturbed({-1.135, -1.01, 0.0001, 1.01, 1.135, 4.6, 447.0});
  auto rounded = from_fp8(to_fp8(perturbed), float32);
  CHECK(array_equal(rounded, nearest).item<bool>());

  // Overflow
  array too_large({-600.0, 600.0});
  auto saturated = from_fp8(to_fp8(too_large), float32);

  auto expected = array({-448.0f, 448.0f});
  CHECK(array_equal(saturated, expected, true).item<bool>());
}
|
||||
|
||||
Reference in New Issue
Block a user