Commit: mxfp4 quantize/dequantize + start of optional biases
Repository: https://github.com/ml-explore/mlx.git
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -98,9 +98,11 @@ class QuantizedEmbedding(Module):
         # Initialize the quantized weight
         scale = math.sqrt(1 / dims)
         weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
-        self.weight, self.scales, self.biases = mx.quantize(
-            weight, group_size, bits, mode=mode
-        )
+        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
+        if mode == "affine":
+            self.scales, self.biases = scales_biases
+        else:
+            self.scales = scales_biases[0]
         self.num_embeddings = num_embeddings
         self.dims = dims

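With this change `mx.quantize` returns a flat tuple whose length depends on the mode, and the layers star-unpack everything after the packed weights. A minimal sketch of the two calling conventions, based on the tests in this commit:

```python
import mlx.core as mx

w = mx.random.normal(shape=(128, 512))

# Affine mode: three elements (w_q, scales, biases).
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")

# mxfp4 mode: two elements (w_q, scales); there are no biases to unpack.
w_q4, scales4 = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
```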
@@ -108,10 +110,11 @@ class QuantizedEmbedding(Module):
         self.freeze()

     def __call__(self, x):
+        biases = self.get("biases")
         return mx.dequantize(
             self["weight"][x],
             scales=self["scales"][x],
-            biases=self["biases"][x],
+            biases=biases[x] if biases is not None else None,
             group_size=self.group_size,
             bits=self.bits,
             mode=self.mode,
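Because `biases` is no longer guaranteed to exist, the lookup indexes it only when present; `Module` is dict-like, so `self.get("biases")` returns `None` for non-affine modes. A hedged usage sketch, assuming the constructor exposes the `mode` keyword shown in the hunk above (sizes are illustrative):

```python
import mlx.core as mx
import mlx.nn as nn

# mxfp4 requires group_size=32 and bits=4 per the tests in this commit.
emb = nn.QuantizedEmbedding(1000, 256, group_size=32, bits=4, mode="mxfp4")
vecs = emb(mx.array([1, 2, 3]))  # dequantizes only the looked-up rows, no biases
```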
@@ -128,7 +131,7 @@ class QuantizedEmbedding(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
@@ -207,9 +210,11 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(
-            weight, group_size, bits, mode=mode
-        )
+        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
+        if mode == "affine":
+            self.scales, self.biases = scales_biases
+        else:
+            self.scales = scales_biases[0]

         # And bias if needed
         if bias:
@@ -231,7 +236,7 @@ class QuantizedLinear(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases"),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
@@ -252,12 +257,17 @@ class QuantizedLinear(Module):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
-        ql.weight, ql.scales, ql.biases = mx.quantize(
+        ql.weight, *scales_biases = mx.quantize(
             linear_layer.weight,
             group_size,
             bits,
             mode=mode,
         )
+        if mode == "affine":
+            ql.scales, ql.biases = scales_biases
+        else:
+            ql.scales = scales_biases[0]
+
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias

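Converting an existing layer still works as before for the default affine mode; only the internal unpacking changed. A sketch of the `from_linear` path (shapes are illustrative):

```python
import mlx.core as mx
import mlx.nn as nn

lin = nn.Linear(256, 256, bias=True)
ql = nn.QuantizedLinear.from_linear(lin, group_size=64, bits=4)

# Affine mode populates ql.scales and ql.biases; the layer bias is copied over.
y = ql(mx.random.normal(shape=(1, 256)))
```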
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4153,7 +4153,7 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       nb::arg(),
       "scales"_a,
-      "biases"_a,
+      "biases"_a = nb::none(),
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
@@ -4161,7 +4161,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
        Perform the matrix multiplication with the quantized matrix ``w``. The
        quantization uses one floating point scale and bias per ``group_size`` of
@@ -4172,7 +4172,8 @@ void init_ops(nb::module_& m) {
          x (array): Input array
          w (array): Quantized matrix packed in unsigned integers
          scales (array): The scales to use per ``group_size`` elements of ``w``
-          biases (array): The biases to use per ``group_size`` elements of ``w``
+          biases (array, optional): The biases to use per ``group_size``
+            elements of ``w``. Default: ``None``.
          transpose (bool, optional): Defines whether to multiply with the
            transposed ``w`` or not, namely whether we are performing
            ``x @ w.T`` or ``x @ w``. Default: ``True``.
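With the binding change, `biases` defaults to `None`, so non-affine calls can simply omit it. A sketch exercising the mxfp4 path (shapes are illustrative):

```python
import mlx.core as mx

x = mx.random.normal(shape=(2, 512))
w = mx.random.normal(shape=(256, 512))
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")

# biases is omitted: it now defaults to None for non-affine modes.
y = mx.quantized_matmul(
    x, w_q, scales, transpose=True, group_size=32, bits=4, mode="mxfp4"
)  # computes x @ w.T against the packed weights; y.shape == (2, 256)
```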
@@ -4220,11 +4221,11 @@ void init_ops(nb::module_& m) {
          mode (str, optional): The quantization mode. Default: ``"affine"``.

        Returns:
-          tuple: A tuple containing
+          tuple: A tuple with either two or three elements containing:

          * w_q (array): The quantized version of ``w``
-          * scales (array): The scale to multiply each element with, namely :math:`s`
-          * biases (array): The biases to add to each element, namely :math:`\beta`
+          * scales (array): The quantization scales
+          * biases (array): The quantization biases (returned for `mode=="affine"`).

        Notes:
          The currently supported quantization mode is `"affine"`.
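The updated return documentation in practice: unpack three values for affine and two for mxfp4, then feed the same pieces back to `mx.dequantize`. A round-trip sketch for the affine default:

```python
import mlx.core as mx

w = mx.random.normal(shape=(128, 256))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)  # mode="affine"
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
# w_hat approximates w up to the 4-bit quantization error
```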
@@ -4256,14 +4257,14 @@ void init_ops(nb::module_& m) {
|
||||
&mx::dequantize,
|
||||
nb::arg(),
|
||||
"scales"_a,
|
||||
"biases"_a,
|
||||
"biases"_a = nb::none(),
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"mode"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dequantize the matrix ``w`` using quantization parameters.
|
||||
|
||||
@@ -4272,7 +4273,8 @@ void init_ops(nb::module_& m) {
        Args:
          w (array): Matrix to be quantized
          scales (array): The scales to use per ``group_size`` elements of ``w``
-          biases (array): The biases to use per ``group_size`` elements of ``w``
+          biases (array, optional): The biases to use per ``group_size``
+            elements of ``w``. Default: ``None``.
          group_size (int, optional): The size of the group in ``w`` that shares a
            scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element in
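For mxfp4 the same round trip drops the biases entirely; the tests in this commit pin `group_size` to 32 and `bits` to 4 for that mode. A minimal sketch:

```python
import mlx.core as mx

w = mx.random.normal(shape=(128, 256))
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")

# No biases argument: the new default of None is what mxfp4 expects.
w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
```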
@@ -4298,7 +4300,7 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       nb::arg(),
       "scales"_a,
-      "biases"_a,
+      "biases"_a = nb::none(),
       "lhs_indices"_a = nb::none(),
       "rhs_indices"_a = nb::none(),
       "transpose"_a = true,
@@ -4309,7 +4311,7 @@ void init_ops(nb::module_& m) {
       "sorted_indices"_a = false,
       "stream"_a = nb::none(),
       nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
        Perform quantized matrix multiplication with matrix-level gather.

@@ -4325,7 +4327,8 @@ void init_ops(nb::module_& m) {
          x (array): Input array
          w (array): Quantized matrix packed in unsigned integers
          scales (array): The scales to use per ``group_size`` elements of ``w``
-          biases (array): The biases to use per ``group_size`` elements of ``w``
+          biases (array, optional): The biases to use per ``group_size``
+            elements of ``w``. Default: ``None``.
          lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
          rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
          transpose (bool, optional): Defines whether to multiply with the
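`gather_qmm` keeps `biases` in the same position but lets it default to `None`. A hedged sketch of a mixture-of-experts style call, assuming `mx.quantize` accepts stacked expert matrices (it quantizes groups along the last axis); the shapes are illustrative:

```python
import mlx.core as mx

# Two stacked "expert" matrices, quantized together.
w = mx.random.normal(shape=(2, 32, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

x = mx.random.normal(shape=(4, 1, 512))
rhs_indices = mx.array([0, 1, 1, 0])  # pick an expert per row of x

y = mx.gather_qmm(
    x, w_q, scales, biases,
    rhs_indices=rhs_indices, transpose=True, group_size=64, bits=4,
)  # y.shape == (4, 1, 32)
```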
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -27,6 +27,56 @@ class TestQuantized(mlx_tests.MLXTestCase):
         a_hat = mx.dequantize(w_q, scales, biases, gs, b)
         self.assertTrue(mx.all(a_hat == 0))

+    def test_mxfp4_quantize_dequantize(self):
+        lut = mx.array(
+            [
+                +0.0,
+                +0.5,
+                +1.0,
+                +1.5,
+                +2.0,
+                +3.0,
+                +4.0,
+                +6.0,
+                -0.0,
+                -0.5,
+                -1.0,
+                -1.5,
+                -2.0,
+                -3.0,
+                -4.0,
+                -6.0,
+            ]
+        )
+        w = lut[mx.random.randint(0, 16, shape=(128, 512))]
+        w = w.reshape(-1, 32)
+        w[:, 0] = 6
+        w = (w + 3e-6).astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, group_size=32, mode="mxfp4")
+
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
+
+        w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
+
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4")
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        self.assertTrue(mx.all(w_hat == 0))
+
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
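The `lut` in the test above is exactly the value set of the FP4 E2M1 encoding that mxfp4 packs into each nibble (1 sign bit, 2 exponent bits, 1 mantissa bit); forcing `w[:, 0] = 6` makes every group's max magnitude hit the format's largest value, so the shared power-of-two scale reconstructs the group exactly within the test tolerance. A reference decoder for one nibble, for illustration only (not the library's implementation):

```python
def fp4_e2m1_to_float(nibble: int) -> float:
    """Decode a 4-bit E2M1 code: 1 sign, 2 exponent, 1 mantissa bit."""
    sign = -1.0 if (nibble >> 3) & 1 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 1
    if exp == 0:
        return sign * 0.5 * man  # subnormals: +/-0.0 and +/-0.5
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

# [fp4_e2m1_to_float(i) for i in range(16)] reproduces the lut above.
```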
@@ -233,6 +283,71 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 2e-3)

+    def test_mode_error_cases(self):
+        w = mx.random.normal(shape=(256, 256))
+        x = mx.random.normal(shape=(1, 256))
+
+        # Invalid mode
+        with self.assertRaises(ValueError):
+            mx.quantize(w, mode="xyz")
+
+        wq, scales, biases = mx.quantize(w, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode="xyz")
+
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(
+                x, wq, scales, biases, bits=4, group_size=32, mode="xyz"
+            )
+
+        rhs_indices = mx.array(0)
+        with self.assertRaises(ValueError):
+            mx.gather_qmm(
+                x,
+                wq,
+                scales,
+                biases,
+                rhs_indices=rhs_indices,
+                bits=4,
+                group_size=32,
+                mode="xyz",
+            )
+
+        # Only quantize floating point types
+        with self.assertRaises(ValueError):
+            mx.quantize(mx.zeros((128, 128), mx.int32))
+
+        with self.assertRaises(ValueError):
+            mx.quantize(mx.zeros((128, 128), mx.int32), mode="mxfp4")
+
+        # Must have bias for affine
+        with self.assertRaises(ValueError):
+            mx.dequantize(wq, scales, None, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.gather_qmm(
+                x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32
+            )
+
+        # Must be floating point
+        x = mx.zeros(shape=(256,), dtype=mx.int32)
+        scales = mx.zeros(scales.shape, dtype=mx.int32)
+        biases = mx.zeros(scales.shape, dtype=mx.int32)
+        with self.assertRaises(ValueError):
+            mx.dequantize(wq, scales, biases, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32)
+
+        with self.assertRaises(ValueError):
+            mx.gather_qmm(
+                x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32
+            )
+
     def test_throw(self):
         x = mx.random.normal(shape=(10, 512))
         w = mx.random.normal(shape=(32, 512))