mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 20:07:59 +08:00
Fix qmm_t for unaligned cases (#923)
This commit is contained in:
parent
46caf0bef0
commit
5f9ba3019f
@ -520,6 +520,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
const int K_g = K / group_size;
|
const int K_g = K / group_size;
|
||||||
const int y_row = tid.y * BM;
|
const int y_row = tid.y * BM;
|
||||||
const int y_col = tid.x * BN;
|
const int y_col = tid.x * BN;
|
||||||
|
|
||||||
x += y_row * K;
|
x += y_row * K;
|
||||||
w += y_col * K_w;
|
w += y_col * K_w;
|
||||||
scales += y_col * K_g;
|
scales += y_col * K_g;
|
||||||
@ -572,7 +573,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
|
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
|
||||||
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
|
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
|
||||||
|
|
||||||
if (y_row + offset_row < N) {
|
// y_col corresponds to the row of the weight matrix and added to
|
||||||
|
// offset_row it should be less than the total number of rows
|
||||||
|
// otherwise skip.
|
||||||
|
if (y_col + offset_row < N) {
|
||||||
uint32_t wi = *w_local;
|
uint32_t wi = *w_local;
|
||||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||||
|
@ -229,6 +229,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(y_q.shape, y_hat.shape)
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
# Test with larger than 128 unaligned sizes
|
||||||
|
w = mx.random.normal(shape=(99, 256))
|
||||||
|
w_q, scales, biases = mx.quantize(w)
|
||||||
|
w_hat = mx.dequantize(w_q, scales, biases)
|
||||||
|
x = mx.random.normal(shape=(129, 256))
|
||||||
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||||
|
y_hat = x @ w_hat.T
|
||||||
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user