mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
Fix batched qmv bug (#1758)
This commit is contained in:
@@ -212,11 +212,12 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
w_hat = mx.dequantize(w_q, scales, biases)
|
||||
|
||||
# Test qmv
|
||||
x = mx.random.normal(shape=(3, 1, 256))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
for shape in [(3, 1, 256), (3, 4, 256)]:
|
||||
x = mx.random.normal(shape=shape)
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
# Test qmm_t
|
||||
x = mx.random.normal(shape=(3, 10, 256))
|
||||
|
Reference in New Issue
Block a user