mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices
* Fix qmm
This commit is contained in:
parent
4cc70290f7
commit
40c108766b
@ -39,11 +39,12 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
|||||||
|
|
||||||
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
|
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
|
||||||
|
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
constexpr int bitmask = (1 << bits) - 1;
|
constexpr int bitmask = (1 << bits) - 1;
|
||||||
constexpr int el_per_thread = 32 / bits;
|
constexpr int el_per_thread = 32 / bits;
|
||||||
constexpr int colgroup = BN * el_per_thread;
|
constexpr int colgroup = BN * el_per_thread;
|
||||||
constexpr int groups_per_block = colgroup / group_size;
|
constexpr int groups_per_block = colgroup / group_size;
|
||||||
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
|
|
||||||
|
|
||||||
typedef typename AccT<T>::acc_t U;
|
typedef typename AccT<T>::acc_t U;
|
||||||
threadgroup U scales_block[BM * groups_per_block];
|
threadgroup U scales_block[BM * groups_per_block];
|
||||||
@ -66,12 +67,19 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
|||||||
x += tid.z * in_vec_size;
|
x += tid.z * in_vec_size;
|
||||||
y += tid.z * out_vec_size;
|
y += tid.z * out_vec_size;
|
||||||
|
|
||||||
|
if (out_row >= out_vec_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Loop over in_vec in blocks of colgroup
|
// Loop over in_vec in blocks of colgroup
|
||||||
for (int i=0; i<in_vec_size; i+=colgroup) {
|
for (int i=0; i<in_vec_size; i+=colgroup) {
|
||||||
// Load the vec to shared memory
|
// Load the vec to shared memory
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
if (simd_gid < simdgroups_fetching_vec) {
|
if (simd_gid == 0) {
|
||||||
x_block[lid] = x[lid + i];
|
#pragma clang loop unroll(full)
|
||||||
|
for (int j=0; j<el_per_thread; j++) {
|
||||||
|
x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (simd_lid == 0) {
|
if (simd_lid == 0) {
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
@ -250,7 +258,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
|
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
|
||||||
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
|
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
|
||||||
|
|
||||||
|
|
||||||
threadgroup T scales_block[BN * groups_per_block];
|
threadgroup T scales_block[BN * groups_per_block];
|
||||||
threadgroup T biases_block[BN * groups_per_block];
|
threadgroup T biases_block[BN * groups_per_block];
|
||||||
threadgroup T Xs[BM * BK];
|
threadgroup T Xs[BM * BK];
|
||||||
@ -313,7 +320,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
|
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
|
||||||
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
|
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
|
||||||
|
|
||||||
if (y_col + offset_col < N) {
|
if (y_row + offset_row < N) {
|
||||||
uint32_t wi = *w_local;
|
uint32_t wi = *w_local;
|
||||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||||
@ -428,8 +435,9 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
for (int k=0; k<K; k += BK) {
|
for (int k=0; k<K; k += BK) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
// Load the x tile
|
// Load the x tile
|
||||||
if (num_els < BM) {
|
short num_k = min(BK, K - k);
|
||||||
loader_x.load_safe(short2(BK, num_els));
|
if (num_els < BM || num_k < BK) {
|
||||||
|
loader_x.load_safe(short2(num_k, num_els));
|
||||||
} else {
|
} else {
|
||||||
loader_x.load_unsafe();
|
loader_x.load_unsafe();
|
||||||
}
|
}
|
||||||
@ -457,7 +465,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
|||||||
|
|
||||||
// Load the w tile
|
// Load the w tile
|
||||||
{
|
{
|
||||||
if (k + BK >= K) {
|
if (num_k < BK) {
|
||||||
for (int wo=0; wo<w_els_per_thread; wo++) {
|
for (int wo=0; wo<w_els_per_thread; wo++) {
|
||||||
int offset = lid * w_els_per_thread + wo;
|
int offset = lid * w_els_per_thread + wo;
|
||||||
int offset_row = offset / (BN / el_per_int);
|
int offset_row = offset / (BN / el_per_int);
|
||||||
|
@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
int bo = std::min(32, O);
|
int bo = std::min(32, O);
|
||||||
int bd = 32;
|
int bd = 32;
|
||||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||||
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
|
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
|
||||||
|
|
||||||
set_array_buffer(compute_encoder, w, 0);
|
set_array_buffer(compute_encoder, w, 0);
|
||||||
set_array_buffer(compute_encoder, scales, 1);
|
set_array_buffer(compute_encoder, scales, 1);
|
||||||
|
@ -165,6 +165,70 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(y_q.shape, y_hat.shape)
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
def test_non_multiples(self):
|
||||||
|
w = mx.random.normal(shape=(33, 256))
|
||||||
|
w_q, scales, biases = mx.quantize(w)
|
||||||
|
w_hat = mx.dequantize(w_q, scales, biases)
|
||||||
|
|
||||||
|
# Test qmv
|
||||||
|
x = mx.random.normal(shape=(1, 256))
|
||||||
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||||
|
y_hat = x @ w_hat.T
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
# Test qmm_t
|
||||||
|
x = mx.random.normal(shape=(10, 256))
|
||||||
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||||
|
y_hat = x @ w_hat.T
|
||||||
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
# Test qvm
|
||||||
|
x = mx.random.normal(shape=(1, 33))
|
||||||
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
|
||||||
|
y_hat = x @ w_hat
|
||||||
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
# Test qmm
|
||||||
|
x = mx.random.normal(shape=(10, 33))
|
||||||
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
|
||||||
|
y_hat = x @ w_hat
|
||||||
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
# Smaller than 8
|
||||||
|
w = mx.random.normal(shape=(3, 256))
|
||||||
|
w_q, scales, biases = mx.quantize(w)
|
||||||
|
w_hat = mx.dequantize(w_q, scales, biases)
|
||||||
|
|
||||||
|
# Test qmv
|
||||||
|
x = mx.random.normal(shape=(1, 256))
|
||||||
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||||
|
y_hat = x @ w_hat.T
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
# Test qmm_t
|
||||||
|
x = mx.random.normal(shape=(10, 256))
|
||||||
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||||
|
y_hat = x @ w_hat.T
|
||||||
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
# Test qvm
|
||||||
|
x = mx.random.normal(shape=(1, 3))
|
||||||
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
|
||||||
|
y_hat = x @ w_hat
|
||||||
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
# Test qmm
|
||||||
|
x = mx.random.normal(shape=(10, 3))
|
||||||
|
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
|
||||||
|
y_hat = x @ w_hat
|
||||||
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user