mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 22:44:38 +08:00
Support for quantized matmul with w and w^T (#349)
* Add the metal qvm implementation * Add qmm_n * Add gradient wrt to input for quantized_matmul
This commit is contained in:

committed by
GitHub

parent
d7ac050f4b
commit
e7f5059fe4
@@ -1,6 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import unittest
|
||||
from itertools import product
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
@@ -19,62 +20,116 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
def test_qmm(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
for group_size in [128, 64]:
|
||||
for bits in [2, 4, 8]:
|
||||
for M in [8, 32, 33, 64]:
|
||||
for N in [512, 1024]:
|
||||
for K in [512, 1024]:
|
||||
with self.subTest(
|
||||
shape=(M, N, K), group_size=group_size, bits=bits
|
||||
):
|
||||
x = mx.random.normal(shape=(M, K), key=k1)
|
||||
w = mx.random.normal(shape=(N, K), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(
|
||||
w_q, scales, biases, group_size, bits
|
||||
)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q.T, scales, biases, group_size, bits
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
tests = product(
|
||||
[128, 64], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[8, 32, 33, 64], # M
|
||||
[512, 1024], # N
|
||||
[512, 1024], # K
|
||||
[True, False], # transposed
|
||||
)
|
||||
for group_size, bits, M, N, K, transposed in tests:
|
||||
with self.subTest(
|
||||
shape=(M, N, K),
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
transposed=transposed,
|
||||
):
|
||||
x = mx.random.normal(shape=(M, K), key=k1)
|
||||
w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transposed, group_size, bits
|
||||
)
|
||||
y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_qmm_shapes(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
group_size = 64
|
||||
bits = 4
|
||||
w = mx.random.normal(shape=(32, 128), key=k2)
|
||||
w = mx.random.normal(shape=(32, 256), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
for s in [(3, 128), (2, 1, 7, 128)]:
|
||||
x = mx.random.normal(shape=(3, 128), key=k1)
|
||||
y_q = mx.quantized_matmul(x, w_q.T, scales, biases, group_size, bits)
|
||||
for s in [(3, 256), (2, 1, 7, 256)]:
|
||||
x = mx.random.normal(shape=s, key=k1)
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
w = mx.random.normal(shape=(256, 256), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
for s in [(3, 256), (2, 1, 7, 256)]:
|
||||
x = mx.random.normal(shape=s, key=k1)
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_qmv(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
for group_size in [128, 64]:
|
||||
for bits in [2, 4, 8]:
|
||||
for M in [512, 1024]:
|
||||
for N in [512, 1024]:
|
||||
with self.subTest(
|
||||
shape=(M, N), group_size=group_size, bits=bits
|
||||
):
|
||||
x = mx.random.normal(shape=(1, N), key=k1)
|
||||
w = mx.random.normal(shape=(M, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q.T, scales, biases, group_size, bits
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
tests = product(
|
||||
[128, 64], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024], # N
|
||||
)
|
||||
for group_size, bits, M, N in tests:
|
||||
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
|
||||
x = mx.random.normal(shape=(1, N), key=k1)
|
||||
w = mx.random.normal(shape=(M, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, True, group_size, bits
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_qvm(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
tests = product(
|
||||
[128, 64], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024], # N
|
||||
)
|
||||
for group_size, bits, M, N in tests:
|
||||
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
|
||||
x = mx.random.normal(shape=(1, N), key=k1)
|
||||
w = mx.random.normal(shape=(N, M), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, False, group_size, bits
|
||||
)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_throw(self):
|
||||
x = mx.random.normal(shape=(10, 512))
|
||||
w = mx.random.normal(shape=(32, 512))
|
||||
w_q, scales, biases = mx.quantize(w)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantized_matmul(x, w_q.T, scales, biases)
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantized_matmul(x, w_q.T, scales.T, biases)
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantized_matmul(x, w_q, scales, biases, False)
|
||||
with self.assertRaises(ValueError):
|
||||
mx.quantized_matmul(x, w_q, scales.T, biases.T)
|
||||
y = mx.quantized_matmul(x, w_q, scales, biases, True)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user