Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-04 15:04:40 +08:00)
An initial quantized matmul implementation (#205)

* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
committed by GitHub
parent e6872a4149
commit dfa9f4bc58
@@ -2977,4 +2977,36 @@ void init_ops(py::module_& m) {
        Returns:
            result (array): An array of the same type as ``a`` rounded to the given number of decimals.
      )pbdoc");
  m.def(
      "quantized_matmul",
      &quantized_matmul,
      "x"_a,
      "w"_a,
      py::pos_only(),
      "scales"_a,
      "biases"_a,
      "groups"_a = 128,
      "width"_a = 4,
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        quantized_matmul(x: array, w: array, scales: array, biases: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array

        Perform the matrix multiplication with the quantized matrix ``w``. The
        quantization uses one floating point scale and bias per ``groups`` of
        elements. Each element in ``w`` takes ``width`` bits and is packed in an
        unsigned 32 bit integer.

        Args:
            x (array): Input array
            w (array): Quantized matrix packed in unsigned integers
            scales (array): The scales to use per ``groups`` elements of ``w``
            biases (array): The biases to use per ``groups`` elements of ``w``
            groups (int): The size of the group in ``w`` that shares a scale and
              bias. (default: 128)
            width (int): The bitwidth of the elements in ``w``. (default: 4)

        Returns:
            result (array): The result of the multiplication of ``x`` with ``w``.
      )pbdoc");
}
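In plain terms, the binding above exposes one new op: ``x`` stays floating point, while ``w`` arrives pre-packed together with its per-group ``scales`` and ``biases``. A minimal usage sketch (not part of the diff), mirroring the tests added below; it borrows the ``quantize`` helper defined in python/tests/test_quantized.py later in this commit, and the shapes are illustrative:

    import mlx.core as mx

    x = mx.random.normal(shape=(8, 512))    # activations, (M, K)
    w = mx.random.normal(shape=(512, 512))  # weights, (N, K)
    w_q, scales, biases = quantize(w, width=4, groups=64)

    # As in the tests, the packed matrix is passed transposed; the result
    # approximates x @ dequantize(w_q, scales, biases, 4).T.
    y = mx.quantized_matmul(x, w_q.T, scales, biases, groups=64, width=4)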
python/tests/test_quantized.py (new file, 112 lines)
@@ -0,0 +1,112 @@
# Copyright © 2023 Apple Inc.

import unittest

import mlx.core as mx
import mlx_tests


def select_bits(w, width, start):
    shift_left = 32 - (start + width)
    shift_right = shift_left + start
    return (w * (2**shift_left)) // (2**shift_right)
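``select_bits`` pulls a ``width``-bit field starting at bit ``start`` out of each packed word using only multiplication and floor division: the multiply wraps modulo 2**32 (the inputs are packed as ``uint32``), discarding everything above the field, and the divide shifts the field down to the low bits. A quick sanity check of that arithmetic on a plain Python int (``select_bits_int`` is a hypothetical stand-in, not part of the diff):

    def select_bits_int(x, width, start):
        # The uint32 wraparound of `x * 2**shift_left` is a masked left
        # shift; `// 2**shift_right` is a right shift.
        shift_left = 32 - (start + width)
        shift_right = shift_left + start
        return ((x << shift_left) & 0xFFFFFFFF) >> shift_right

    packed = 0xB6  # two 4-bit fields: 0x6 at start=0, 0xB at start=4
    assert select_bits_int(packed, 4, 0) == 0x6
    assert select_bits_int(packed, 4, 4) == 0xB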
def dequantize(w, scales, biases, width):
    w_full = mx.concatenate(
        [select_bits(w, width, i)[..., None] for i in range(0, 32, width)], axis=-1
    )
    w_full = w_full.reshape(len(w), scales.shape[-1], -1)
    w_full = scales[..., None] * w_full + biases[..., None]
    w_full = w_full.reshape(len(w), -1)

    return w_full
def quantize(w, width, groups):
    w = w.reshape(len(w), -1, groups)
    w_max = w.max(-1, keepdims=True)
    w_min = w.min(-1, keepdims=True)
    delta = (w_max - w_min) / (2**width - 1)

    w_int = mx.round((w - w_min) / delta).astype(mx.uint32)
    scales = delta.squeeze(-1)
    biases = w_min.squeeze(-1)

    shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32)
    w_int = w_int.reshape(len(w), -1, 32 // width)
    w_int = w_int * shifts[None, None]
    packed_w = w_int.sum(-1)

    return packed_w, scales, biases
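Together the two helpers round-trip: ``quantize`` maps each group of ``groups`` values to codes ``w_int = round((w - w_min) / delta)`` with ``delta = (w_max - w_min) / (2**width - 1)``, then packs ``32 // width`` codes into each ``uint32``; ``dequantize`` reconstructs ``scale * w_int + bias``. A small walk-through under simple assumptions (inputs chosen so ``delta`` is exactly 1.0 and the round trip is lossless; uses the helpers above):

    # One row of 16 values at width=4, groups=16: each value becomes its
    # own 4-bit code, and the packing is visible in hex, least
    # significant nibble first.
    w = mx.arange(16, dtype=mx.float32)[None]  # shape (1, 16)
    w_q, scales, biases = quantize(w, width=4, groups=16)
    print(hex(w_q[0, 0].item()), hex(w_q[0, 1].item()))  # 0x76543210 0xfedcba98
    assert (dequantize(w_q, scales, biases, 4) - w).abs().max().item() == 0.0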
class TestQuantized(mlx_tests.MLXTestCase):
    def test_quantize_dequantize(self):
        w = mx.random.normal(shape=(128, 128))
        w_q, scales, biases = quantize(w, 4, 64)
        w_hat = dequantize(w_q, scales, biases, 4)
        w_hat2 = dequantize(*quantize(w_hat, 4, 64), 4)
        self.assertLess((w_hat - w_hat2).abs().max(), 1e-6)

    def test_qmm(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        for groups in [128, 64]:
            for width in [2, 4, 8]:
                for M in [8, 32, 33, 64]:
                    for N in [512, 1024]:
                        for K in [512, 1024]:
                            with self.subTest(
                                shape=(M, N, K), groups=groups, width=width
                            ):
                                x = mx.random.normal(shape=(M, K), key=k1)
                                w = mx.random.normal(shape=(N, K), key=k2)
                                w_q, scales, biases = quantize(w, width, groups)
                                w_hat = dequantize(w_q, scales, biases, width)
                                y_q = mx.quantized_matmul(
                                    x, w_q.T, scales, biases, width=width, groups=groups
                                )
                                y_hat = x @ w_hat.T
                                self.assertEqual(y_q.shape, y_hat.shape)
                                self.assertLess((y_q - y_hat).abs().max(), 0.1)
    def test_qmm_shapes(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        groups = 64
        width = 4
        w = mx.random.normal(shape=(32, 128), key=k2)
        w_q, scales, biases = quantize(w, width, groups)
        w_hat = dequantize(w_q, scales, biases, width)
        for s in [(3, 128), (2, 1, 7, 128)]:
            x = mx.random.normal(shape=s, key=k1)
            y_q = mx.quantized_matmul(
                x, w_q.T, scales, biases, width=width, groups=groups
            )
            y_hat = x @ w_hat.T
            self.assertEqual(y_q.shape, y_hat.shape)
            self.assertLess((y_q - y_hat).abs().max(), 0.1)
    def test_qmv(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        for groups in [128, 64]:
            for width in [2, 4, 8]:
                for M in [512, 1024]:
                    for N in [512, 1024]:
                        with self.subTest(shape=(M, N), groups=groups, width=width):
                            x = mx.random.normal(shape=(1, N), key=k1)
                            w = mx.random.normal(shape=(M, N), key=k2)
                            w_q, scales, biases = quantize(w, width, groups)
                            w_hat = dequantize(w_q, scales, biases, width)
                            y_q = mx.quantized_matmul(
                                x, w_q.T, scales, biases, width=width, groups=groups
                            )
                            y_hat = x @ w_hat.T
                            self.assertEqual(y_q.shape, y_hat.shape)
                            self.assertLess((y_q - y_hat).abs().max(), 0.1)


if __name__ == "__main__":
    unittest.main()