Add quantize/dequantize for mxfp8 and nvfp4 (#2688)

* Add quantize/dequantize slow path for mxfp8 and nvfp4

* fast CUDA kernel for mx/nv quantization

* fallback for CUDA < 12.8 (#2697)

* format (#2700)

* fix (#2701)

* Metal kernels

* docs

* fix JIT

* add default bits and group sizes

* improve quant docs

* fix output type of mxfp4 matmuls
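
For context, here is a minimal sketch of the element encodings behind the two formats, assuming the standard OCP MX and NVIDIA FP4 layouts (mxfp8: fp8 e4m3 elements with one shared power-of-two e8m0 scale per group of 32; nvfp4: fp4 e2m1 elements with one fp8 e4m3 scale per group of 16). The identifiers below are illustrative, not the PR's actual ones:

#include <cmath>
#include <cstdint>
#include <cstdio>

// All 8 non-negative e2m1 magnitudes; bit 3 of the nibble is the sign.
constexpr float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

// Dequantize one fp4 (e2m1) nibble given its group's decoded scale.
float dequant_fp4(uint8_t nibble, float group_scale) {
  float mag = kE2M1[nibble & 0x7];
  return (nibble & 0x8 ? -mag : mag) * group_scale;
}

int main() {
  // A scale of 2^-1: an e8m0 scale (mxfp8) stores just this exponent,
  // while an nvfp4 group scale is a full e4m3 fp8 value.
  float scale = std::ldexp(1.0f, -1);
  std::printf("%g\n", dequant_fp4(0x7, scale)); // 6.0 * 0.5 = 3
  std::printf("%g\n", dequant_fp4(0xF, scale)); // -6.0 * 0.5 = -3
}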
Author: Awni Hannun
Date:   2025-10-28 16:23:12 -07:00
Committed by: GitHub
Parent: 460691a0e8
Commit: ec72b44417

25 changed files with 1400 additions and 588 deletions


@@ -120,7 +120,7 @@ Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
 struct ToFP8 {
   template <typename T, int N>
   Simd<uint8_t, N> operator()(Simd<T, N> f) {
-    uint32_t fp8_max = 1087 << 20;
+    uint32_t fp8_max = 543 << 21;
     auto denorm_mask = Simd<uint32_t, N>(141 << 23);
     Simd<uint32_t, N> f_bits;
     Simd<float, N> f32 = f;
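
The one-line change above swaps the fp32 clamping threshold from 480.0f (1087 << 20 == 0x43F00000) to 448.0f (543 << 21 == 0x43E00000), the largest finite e4m3 value; 480 corresponds to the e4m3 NaN encoding (0b0'1111'111), so clamping there could round inputs into NaN. A standalone sketch to verify the arithmetic (fp32_from_bits is a local helper mirroring the fp32_to_bits seen in the hunk header, not the repo's code):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterpret a 32-bit pattern as an IEEE-754 float.
float fp32_from_bits(uint32_t bits) {
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  // Old threshold: 1.875 * 2^8 = 480, which rounds into e4m3 NaN territory.
  std::printf("old fp8_max = %g\n", fp32_from_bits(1087u << 20)); // 480
  // New threshold: 1.75 * 2^8 = 448, the true e4m3 maximum.
  std::printf("new fp8_max = %g\n", fp32_from_bits(543u << 21)); // 448
}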