[Feature] Add no-parallel-m qmv kernel to improve decoding performance

tianyi 2025-07-22 12:54:04 +08:00
parent b529515eb1
commit b2f0ebe9ee
3 changed files with 184 additions and 1 deletion
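The kernel targets decoding, where x carries only a handful of token rows (small m). The existing qmv path parallelizes the grid over the m input rows; here each simdgroup instead loops over the m rows internally, so every quantized weight pack (and its scale/bias) is loaded and unpacked once and reused m times. A minimal scalar sketch of the computation being performed (bit-unpacking elided; all names here are illustrative, not part of the commit):

// Scalar reference sketch (not the Metal kernel): weights are
// affine-quantized per group of `group_size` values along k,
// w = scale * q + bias. One quantized value per byte here for clarity;
// the real kernel keeps them bit-packed.
#include <cstdint>
#include <vector>

void qmv_no_parallel_m_ref(
    const std::vector<uint8_t>& q, // [n][k] quantized weights
    const std::vector<float>& scales, // [n][k / group_size]
    const std::vector<float>& biases, // [n][k / group_size]
    const std::vector<float>& x, // [m][k] activations
    std::vector<float>& y, // [m][n] output
    int m, int k, int n, int group_size) {
  const int groups = k / group_size;
  for (int row = 0; row < n; row++) {
    for (int col = 0; col < m; col++) { // the loop the kernel keeps serial
      float acc = 0.0f;
      for (int i = 0; i < k; i++) {
        const int g = i / group_size;
        const float w = scales[row * groups + g] * float(q[row * k + i]) +
            biases[row * groups + g];
        acc += w * x[col * k + i];
      }
      y[col * n + row] = acc;
    }
  }
}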

View File

@@ -688,6 +688,82 @@ METAL_FUNC void qmv_fast_impl(
}
}
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_no_parallel_m_impl(
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
const constant int& m_size,
const constant int& k_size,
const constant int& n_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
constexpr int max_batch = 10;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[max_batch * results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = k_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = k_size / group_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
// Unlike qmv_fast_impl, x and y are not pre-offset here; each of the m
// activation/output columns is addressed explicitly in the loops below.
for (int k = 0; k < k_size; k += block_size) {
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
for (int col = 0; col < m_size; col++) {
auto x_temp = x + col * k_size + simd_lid * values_per_thread + k;
U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread);
result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
}
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
for (int col = 0; col < m_size; col++) {
result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]);
auto y_temp = y + col * n_size + out_row;
if (simd_lid == 0) {
y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]);
}
}
}
}
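Two details in the loop above are worth noting. First, as in the existing qmv kernels, load_vector also returns the activation sum for the chunk, because with affine quantization the dot product factors as Σᵢ (s·qᵢ + b)·xᵢ = s·Σᵢ qᵢ·xᵢ + b·Σᵢ xᵢ, so qdot needs only the integer dot product and that precomputed sum. Second, registers are what bound the batch: with 4-bit weights, pack_factor = 8 and packs_per_thread = 2 give values_per_thread = 16, so a simdgroup sweeps block_size = 16 × 32 = 512 values of k per iteration, while results accumulate in a max_batch × results_per_simdgroup = 40-entry per-thread array; m_size therefore must not exceed max_batch = 10.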
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
const device uint32_t* w,
@@ -1410,6 +1486,59 @@ template <typename T, int group_size, int bits, bool batched>
simd_lid);
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_no_parallel_m(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& m_size [[buffer(5)]],
const constant int& k_size [[buffer(6)]],
const constant int& n_size [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant int64_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant int64_t* w_strides [[buffer(13)]],
const constant int64_t* s_strides [[buffer(14)]],
const constant int64_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
n_size * m_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_no_parallel_m_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
m_size,
k_size,
n_size,
tid,
simd_gid,
simd_lid);
}
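The wrapper matches the existing qmv entry point except in how batching is adjusted: the per-batch output offset passed to adjust_matrix_offsets is n_size * m_size, a full m×n output tile per batch element rather than a single output vector.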
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],

View File

@@ -80,6 +80,7 @@
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv_no_parallel_m, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
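Each expansion instantiates the batched and non-batched variants for every (type, group_size, bits) combination; judging from the host-side name construction below, the half-precision 4-bit kernel with 64-value groups comes out as, e.g., qmv_no_parallel_m_float16_gs_64_b_4_batch_1.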

View File

@@ -244,6 +244,59 @@ void qmv(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void qmv_no_parallel_m(
const array& x,
const array& w,
const array& scales,
const array& biases,
array& out,
int group_size,
int bits,
int M,
int N,
int K,
metal::Device& d,
const Stream& s) {
int B = out.size() / M / N;
// Each threadgroup runs num_simdgroups = 2 simdgroups of SIMD_SIZE (32)
// threads and produces num_simdgroups * results_per_simdgroup = 8 output
// rows; m is handled inside the kernel, so the grid is not parallelized
// over it.
int bn = 8;
MTL::Size group_dims(32, 2, 1);
MTL::Size grid_dims(1, (N + bn - 1) / bn, B);
std::string kname;
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
"qmv_no_parallel_m_",
type_string,
"_gs_",
group_size,
"_b_",
bits,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, "qmv_no_parallel_m", type_string, group_size, bits, B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(M, 5);
compute_encoder.set_bytes(K, 6);
compute_encoder.set_bytes(N, 7);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
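Under that launch shape, a decode step with, say, N = 4096 and B = 1 dispatches a (1, 512, 1) grid of 64-thread threadgroups, each producing 8 rows of y for all M activation rows in a single pass over its slice of the quantized weights.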
void qvm_split_k(
const array& x,
const array& w,
@@ -818,7 +871,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Run of the mill qmv
if (transpose_) {
-  qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
+  qmv_no_parallel_m(
+      x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
return;
}
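One caveat: the kernel statically sizes its accumulator for max_batch = 10 rows, while this hunk routes every transposed quantized matmul in the branch through the new kernel unconditionally. Presumably the enclosing dispatch conditions (outside the hunk) already restrict this branch to small M; if not, an explicit M <= 10 guard with a fallback to the old qmv would be required.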