mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Fix qmm_t for unaligned cases (#923)
This commit is contained in:

committed by
GitHub

parent
46caf0bef0
commit
5f9ba3019f
@@ -229,6 +229,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
# Test with larger than 128 unaligned sizes
|
||||
w = mx.random.normal(shape=(99, 256))
|
||||
w_q, scales, biases = mx.quantize(w)
|
||||
w_hat = mx.dequantize(w_q, scales, biases)
|
||||
x = mx.random.normal(shape=(129, 256))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user