add default bits and group sizes
python/src/ops.cpp

@@ -4194,13 +4194,13 @@ void init_ops(nb::module_& m) {
       "scales"_a,
       "biases"_a = nb::none(),
       "transpose"_a = true,
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -4217,9 +4217,11 @@ void init_ops(nb::module_& m) {
             transposed ``w`` or not, namely whether we are performing
             ``x @ w.T`` or ``x @ w``. Default: ``True``.
           group_size (int, optional): The size of the group in ``w`` that
-            shares a scale and bias. Default: ``64``.
+            shares a scale and bias. If unspecified, a default is chosen based
+            on the mode. Default: ``None``.
           bits (int, optional): The number of bits occupied by each element in
-            ``w``. Default: ``4``.
+            ``w``. If unspecified, a default is chosen based on the mode.
+            Default: ``None``.
           mode (str, optional): The quantization mode. Default: ``"affine"``.

         Returns:
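
With the new bindings, ``group_size`` and ``bits`` can simply be omitted and are resolved from ``mode``. A minimal sketch of the Python-side effect (the array shapes here are illustrative assumptions, not from the commit):

```python
import mlx.core as mx

x = mx.random.normal((2, 512))
w = mx.random.normal((512, 512))

# Quantize with the affine defaults, then omit group_size/bits in the
# matmul as well; both calls resolve them from mode ("affine" -> 64, 4).
w_q, scales, biases = mx.quantize(w, mode="affine")
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
```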
@@ -4229,13 +4231,13 @@ void init_ops(nb::module_& m) {
       "quantize",
       &mx::quantize,
       nb::arg(),
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.

@@ -4268,6 +4270,16 @@ void init_ops(nb::module_& m) {
           * biases (array): The quantization biases (returned for ``mode=="affine"``).

+        Notes:
+
+          ====== ====================== ========================== ============= ========
+          mode   group size             bits                       scale type    has bias
+          ====== ====================== ========================== ============= ========
+          affine 32, 64 (default), 128  2, 3, 4 (default), 5, 6, 8 same as input yes
+          mxfp4  32                     4                          e8m0          no
+          mxfp8  32                     8                          e8m0          no
+          nvfp4  16                     4                          e4m3          no
+          ====== ====================== ========================== ============= ========
+
         The ``"affine"`` mode quantizes groups of :math:`g` consecutive
         elements in a row of ``w``. For each group the quantized
         representation of each element :math:`\hat{w_i}` is computed as follows:
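
The table above gives the per-mode defaults that omitted parameters resolve to. A short sketch of ``quantize`` under the new behavior (the shape is an assumption; the two-value return for the fp modes matches the tests further down):

```python
import mlx.core as mx

w = mx.random.normal((256, 512))

# "affine" returns (w_q, scales, biases); omitted group_size/bits
# resolve to the affine defaults 64 and 4.
w_q, scales, biases = mx.quantize(w, mode="affine")

# The fp modes carry no bias, so only (w_q, scales) is returned;
# for "mxfp4" the defaults resolve to group_size=32, bits=4.
w_q4, scales4 = mx.quantize(w, mode="mxfp4")
```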
@@ -4310,14 +4322,14 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       "scales"_a,
       "biases"_a = nb::none(),
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       "dtype"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', dtype: Optional[Dtype], *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', dtype: Optional[Dtype], *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using quantization parameters.

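
Since ``dequantize`` resolves the same defaults from ``mode``, a round trip only needs the mode repeated on both sides. A minimal sketch (shape assumed):

```python
import mlx.core as mx

w = mx.random.normal((256, 512))

# Quantize and dequantize with matching modes; group_size and bits are
# omitted and resolve identically ("mxfp8" -> group_size=32, bits=8).
w_q, scales = mx.quantize(w, mode="mxfp8")
w_hat = mx.dequantize(w_q, scales, mode="mxfp8")
```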
@@ -4361,14 +4373,14 @@ void init_ops(nb::module_& m) {
       "lhs_indices"_a = nb::none(),
       "rhs_indices"_a = nb::none(),
       "transpose"_a = true,
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       nb::kw_only(),
       "sorted_indices"_a = false,
       "stream"_a = nb::none(),
       nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.

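
The same defaulting applies to ``gather_qmm``. A hedged sketch of a mixture-of-experts style call (the expert count, shapes, and indices below are assumptions for illustration only):

```python
import mlx.core as mx

# Two "expert" weight matrices, quantized with the affine defaults.
w = mx.random.normal((2, 512, 512))
w_q, scales, biases = mx.quantize(w, mode="affine")

x = mx.random.normal((4, 1, 512))
rhs_indices = mx.array([0, 1, 0, 1])  # which expert each row of x uses

# group_size/bits are omitted and resolved from mode, per the new sig.
y = mx.gather_qmm(x, w_q, scales, biases, rhs_indices=rhs_indices)
```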
python/tests/test_quantized.py

@@ -55,17 +55,17 @@ class TestQuantized(mlx_tests.MLXTestCase):

         # Invalid bits / group size
         with self.assertRaises(ValueError):
-            mx.quantize(w, bits=3, group_size=32, mode="mxfp4")
+            mx.quantize(w, bits=3, mode="mxfp4")

         with self.assertRaises(ValueError):
-            mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
+            mx.quantize(w, group_size=64, mode="mxfp4")

-        w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
+        w_q, scales = mx.quantize(w, mode="mxfp4")
         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
+            mx.dequantize(w_q, scales, bits=3, mode="mxfp4")

         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
+            mx.dequantize(w_q, scales, group_size=64, mode="mxfp4")

         # Invalid output type
         with self.assertRaises(ValueError):
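
These tests pin down the intended semantics: an explicit ``group_size`` or ``bits`` that conflicts with the mode still raises ``ValueError``; only omitted parameters fall back to the mode's defaults. The same contract, sketched outside the test harness (shape assumed):

```python
import mlx.core as mx

w = mx.random.normal((256, 512))

# mxfp4 only supports group_size=32 and bits=4, so an explicit
# conflicting value raises even though the defaults would be valid.
try:
    mx.quantize(w, group_size=64, mode="mxfp4")
except ValueError:
    pass  # expected: group_size=64 is invalid for mxfp4

# Omitting both picks the mxfp4 defaults (group_size=32, bits=4).
w_q, scales = mx.quantize(w, mode="mxfp4")
```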
@@ -73,13 +73,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.int32
             )

-        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp4")
         self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))

         # test quantize/dequantize 0s
         a = mx.zeros((256, 512))
-        w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4")
-        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        w_q, scales = mx.quantize(a, mode="mxfp4")
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp4")
         self.assertTrue(mx.all(w_hat == 0))

     def test_mxfp8_quantize_dequantize(self):
@@ -88,26 +88,26 @@ class TestQuantized(mlx_tests.MLXTestCase):

         # Invalid bits / group size
         with self.assertRaises(ValueError):
-            mx.quantize(w, bits=3, group_size=32, mode="mxfp8")
+            mx.quantize(w, bits=3, mode="mxfp8")

         with self.assertRaises(ValueError):
             mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
-        w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
+        w_q, scales = mx.quantize(w, group_size=32, mode="mxfp8")

         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, bits=8, group_size=16, mode="mxfp8")
+            mx.dequantize(w_q, scales, group_size=16, mode="mxfp8")

         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
+            mx.dequantize(w_q, scales, bits=4, mode="mxfp8")

-        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp8")

         self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-1))

         # test quantize/dequantize 0s
         a = mx.zeros((256, 512))
-        w_q, scales = mx.quantize(a, group_size=32, bits=8, mode="mxfp8")
-        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
+        w_q, scales = mx.quantize(a, mode="mxfp8")
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp8")
         self.assertTrue(mx.all(w_hat == 0))

     def test_nvfp4_quantize_dequantize(self):
@@ -138,26 +138,26 @@ class TestQuantized(mlx_tests.MLXTestCase):

         # Invalid bits / group size
         with self.assertRaises(ValueError):
-            mx.quantize(w, bits=3, group_size=16, mode="nvfp4")
+            mx.quantize(w, bits=3, mode="nvfp4")

         with self.assertRaises(ValueError):
-            mx.quantize(w, group_size=64, bits=4, mode="nvfp4")
+            mx.quantize(w, group_size=64, mode="nvfp4")

-        w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")
+        w_q, scales = mx.quantize(w, mode="nvfp4")

         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, bits=4, group_size=32, mode="nvfp4")
+            mx.dequantize(w_q, scales, mode="nvfp4")

         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, group_size=32, bits=4, mode="nvfp4")
+            mx.dequantize(w_q, scales, group_size=32, mode="nvfp4")

-        w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
+        w_hat = mx.dequantize(w_q, scales, mode="nvfp4")
         self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))

         # test quantize/dequantize 0s
         a = mx.zeros((256, 512))
-        w_q, scales = mx.quantize(a, group_size=16, bits=4, mode="nvfp4")
-        w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
+        w_q, scales = mx.quantize(a, mode="nvfp4")
+        w_hat = mx.dequantize(w_q, scales, mode="nvfp4")
         self.assertTrue(mx.all(w_hat == 0))

     def test_qmm(self):
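
Across all four ops the resolution rule is the same: ``None`` means "use the mode's default". A hypothetical helper that mirrors the rule implied by the docs table and the tests above (illustrative only, not part of MLX):

```python
from typing import Optional

# Hypothetical: defaults per mode, as implied by the docs table/tests.
MODE_DEFAULTS = {
    "affine": (64, 4),
    "mxfp4": (32, 4),
    "mxfp8": (32, 8),
    "nvfp4": (16, 4),
}


def resolve_quant_params(
    mode: str, group_size: Optional[int] = None, bits: Optional[int] = None
) -> tuple[int, int]:
    """Return (group_size, bits), filling omitted values from the mode."""
    default_gs, default_bits = MODE_DEFAULTS[mode]
    return (
        group_size if group_size is not None else default_gs,
        bits if bits is not None else default_bits,
    )


assert resolve_quant_params("mxfp4") == (32, 4)
assert resolve_quant_params("affine", group_size=128) == (128, 4)
```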