Add quantize/dequantize for mxfp8 and nvfp4 (#2688)

* Add quantize/dequantize slow path for mxfp8 and nvfp4 (see the format sketch after this list)

* fast cuda kernel for mx/nv quantization

* fallback for cuda < 12.8 (#2697)

* format (#2700)

* fix (#2701)

* metal kernels

* docs

* fix jit

* add default bits and group sizes

* improve quant docs

* fix output type of mxfp4 matmuls
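For context (this note is not part of the commit itself): mxfp8 and nvfp4 are block-scaled floating-point formats. In MXFP8 a group of (typically) 32 elements shares one power-of-two scale and each element is stored as an 8-bit float such as E4M3; in NVFP4 a smaller group (typically 16 elements) shares an FP8 scale and each element is a 4-bit E2M1 float. The sketch below shows one plausible way a shared MXFP8 group scale can be chosen; the function name, group size, and rounding choice are illustrative assumptions, not the kernel added by this commit.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative sketch only, not the MLX implementation.
constexpr int kGroupSize = 32;      // assumed MX block size
constexpr float kE4M3Max = 448.0f;  // largest finite E4M3 value

// Compute a shared power-of-two scale exponent for one group and produce the
// scaled elements. A real kernel would additionally round each scaled value
// to the nearest representable E4M3 code.
void quantize_group_mxfp8(const float* in, float* scaled_out, int* shared_exp) {
  float amax = 0.0f;
  for (int i = 0; i < kGroupSize; ++i) {
    amax = std::max(amax, std::fabs(in[i]));
  }
  // One simple choice: pick the power of two that maps the largest element
  // near the top of the E4M3 range.
  int exp = (amax > 0.0f)
      ? static_cast<int>(std::floor(std::log2(amax / kE4M3Max)))
      : 0;
  *shared_exp = exp;
  float scale = std::ldexp(1.0f, -exp);  // 2^(-exp)
  for (int i = 0; i < kGroupSize; ++i) {
    // Clamp into the representable range; FP8 rounding is elided here.
    scaled_out[i] = std::clamp(in[i] * scale, -kE4M3Max, kE4M3Max);
  }
}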
Author: Awni Hannun
Date: 2025-10-28 16:23:12 -07:00
Committed by: GitHub
Parent: 460691a0e8
Commit: ec72b44417
25 changed files with 1400 additions and 588 deletions


@@ -306,7 +306,7 @@ void affine_dequantize(
   enc.set_input_array(scales);
   enc.set_input_array(biases);
   enc.set_output_array(w);
-  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
+  dispatch_float_types(w.dtype(), "affine_dequantize", [&](auto type_tag) {
     dispatch_groups(group_size_, [&](auto group_size) {
       dispatch_bits(bits_, [&](auto bits) {
         using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
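The nested dispatch_float_types / dispatch_groups / dispatch_bits calls in this hunk map the runtime dtype, group size, and bit width onto compile-time template parameters before the kernel is instantiated. The sketch below shows the general shape of such a dispatcher; it is a simplified assumption for illustration, not the helper used in the MLX source, and the supported bit widths listed are likewise assumed.

#include <stdexcept>
#include <type_traits>

// Illustrative sketch of a runtime-to-compile-time dispatcher.
template <typename F>
void dispatch_bits_example(int bits, F&& f) {
  switch (bits) {
    case 4:
      // Inside the callback, `bits` carries its value in its type, so it can
      // be used as a template or constexpr argument (e.g. bits() == 4).
      f(std::integral_constant<int, 4>{});
      break;
    case 8:
      f(std::integral_constant<int, 8>{});
      break;
    default:
      throw std::invalid_argument("Unsupported bit width");
  }
}

Chaining several such dispatchers, as the hunk does for dtype, group size, and bits, lets a single launch site select among all the pre-instantiated kernel template combinations.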