Fix qmm_t for unaligned cases (#923)

This commit is contained in:
Angelos Katharopoulos
2024-03-28 15:34:57 -07:00
committed by GitHub
parent 46caf0bef0
commit 5f9ba3019f
2 changed files with 15 additions and 1 deletion

View File

@@ -229,6 +229,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test with unaligned sizes, including a dimension larger than 128
w = mx.random.normal(shape=(99, 256))
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)
x = mx.random.normal(shape=(129, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Script entry point: when this file is executed directly (rather than
# imported), hand control to unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()