No quant reshape (#957)

* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
This commit is contained in:
Awni Hannun
2024-04-04 11:52:12 -07:00
committed by GitHub
parent d88d2124b5
commit 039da779d1
2 changed files with 34 additions and 18 deletions

View File

@@ -47,6 +47,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qmm_vjp(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
bits = 8
group_size = 64
M = 64
N = 1024
K = 512
x = mx.random.normal(shape=(2, M, K), key=k1)
c = mx.ones(shape=(2, M, N))
transposes = [True, False]
for transposed in transposes:
w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
def fn(x):
return mx.quantized_matmul(
x, w_q, scales, biases, transposed, group_size, bits
)
_, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))
expected_out = mx.quantized_matmul(
c, w_q, scales, biases, not transposed, group_size, bits
)
self.assertTrue(mx.allclose(vjp_out[0], expected_out))
def test_qmm_shapes(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)