mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add quantize/dequantize slow path for mxfp8 and nvfp4
This commit is contained in:
@@ -4268,10 +4268,11 @@ void init_ops(nb::module_& m) {
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"mode"_a = "affine",
|
||||
"dtype"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', dtype: Optional[Dtype], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dequantize the matrix ``w`` using quantization parameters.
|
||||
|
||||
@@ -4284,6 +4285,10 @@ void init_ops(nb::module_& m) {
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
dtype (Dtype, optional): The data type of the dequantized output. If
|
||||
``None`` the return type is inferred from the scales and biases
|
||||
when possible and otherwise defaults to ``bfloat16``.
|
||||
Default: ``None``.
|
||||
mode (str, optional): The quantization mode. Default: ``"affine"``.
|
||||
|
||||
Returns:
|
||||
|
||||
Reference in New Issue
Block a user