Fix batched qmv bug (#1758)

This commit is contained in:
Alex Barron
2025-01-09 11:45:57 -08:00
committed by GitHub
parent da8c885784
commit c7b0300af5
2 changed files with 22 additions and 13 deletions

View File

@@ -212,11 +212,12 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_hat = mx.dequantize(w_q, scales, biases)
# Test qmv
x = mx.random.normal(shape=(3, 1, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
for shape in [(3, 1, 256), (3, 4, 256)]:
x = mx.random.normal(shape=shape)
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm_t
x = mx.random.normal(shape=(3, 10, 256))