mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Add split_k qvm
for long context (#1564)
* Add splitk qvm * configurable splitk * tuning * remove extra instantiation * remove refactor * separate test * cpu tolerance
This commit is contained in:
@@ -163,6 +163,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_qvm_splitk(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[128], # M
|
||||
[16384], # N
|
||||
[1, 3], # B
|
||||
)
|
||||
for group_size, bits, M, N, B in tests:
|
||||
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
|
||||
x_shape = (1, N) if B == 0 else (B, 1, N)
|
||||
w_shape = (N, M) if B == 0 else (B, N, M)
|
||||
x = mx.random.normal(shape=x_shape, key=k1)
|
||||
w = mx.random.normal(shape=w_shape, key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, False, group_size, bits
|
||||
)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
|
||||
|
||||
def test_throw(self):
|
||||
x = mx.random.normal(shape=(10, 512))
|
||||
w = mx.random.normal(shape=(32, 512))
|
||||
|
Reference in New Issue
Block a user