mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
No quant reshape (#957)
* precise option on cpu
* remove print
* remove reshape in quant matmul
* no quant reshape
This commit is contained in:
@@ -47,6 +47,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_qmm_vjp(self):
    """Check the input VJP of ``mx.quantized_matmul`` for both weight layouts.

    The cotangent is pulled back through the quantized matmul with
    ``mx.vjp`` and the result is compared against a quantized matmul of
    that same cotangent with the transpose flag flipped.
    """
    x_key, w_key = mx.random.split(mx.random.key(0))

    # Quantization settings and problem sizes.
    bits = 8
    group_size = 64
    M, N, K = 64, 1024, 512

    x = mx.random.normal(shape=(2, M, K), key=x_key)
    cotangent = mx.ones(shape=(2, M, N))

    for transposed in (True, False):
        # The weight is laid out (N, K) when the matmul transposes it,
        # and (K, N) otherwise.
        w_shape = (N, K) if transposed else (K, N)
        w = mx.random.normal(shape=w_shape, key=w_key)
        w_q, scales, biases = mx.quantize(w, group_size, bits)

        def forward(inp):
            return mx.quantized_matmul(
                inp, w_q, scales, biases, transposed, group_size, bits
            )

        _, grads = mx.vjp(forward, primals=(x,), cotangents=(cotangent,))

        # Reference: apply the same quantized weight to the cotangent with
        # the opposite transpose setting.
        reference = mx.quantized_matmul(
            cotangent, w_q, scales, biases, not transposed, group_size, bits
        )
        self.assertTrue(mx.allclose(grads[0], reference))
|
||||
|
||||
def test_qmm_shapes(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
|
Reference in New Issue
Block a user