Add quantize/dequantize slow path for mxfp8 and nvfp4

This commit is contained in:
Awni Hannun
2025-10-20 16:53:03 -07:00
parent 5d7efafe92
commit 8afc36cb87
5 changed files with 197 additions and 78 deletions

View File

@@ -4268,10 +4268,11 @@ void init_ops(nb::module_& m) {
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', dtype: Optional[Dtype], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dequantize the matrix ``w`` using quantization parameters.
@@ -4284,6 +4285,10 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
dtype (Dtype, optional): The data type of the dequantized output. If
``None`` the return type is inferred from the scales and biases
when possible and otherwise defaults to ``bfloat16``.
Default: ``None``.
mode (str, optional): The quantization mode. Default: ``"affine"``.
Returns: