Improve perf

This commit is contained in:
TianyiZhao1437 2025-07-23 10:12:38 +08:00
parent 7df3a2887d
commit c7cdd51f50

View File

@ -225,6 +225,25 @@ inline U qdot(
return scale * accum + sum * bias;
}
inline float qdot_bit4(
const device uint16_t* w,
const thread float* x_thread,
float scale,
float bias,
float sum) {
float accum = 0;
for (int i = 0; i < 4; i++) {
accum +=
(x_thread[4 * i] * (w[i] & 0x000f) +
x_thread[4 * i + 1] * (w[i] & 0x00f0) +
x_thread[4 * i + 2] * (w[i] & 0x0f00) +
x_thread[4 * i + 3] * (w[i] & 0xf000));
}
return scale * accum + sum * bias;
}
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
const device uint8_t* w,
@ -728,35 +747,40 @@ METAL_FUNC void qmv_no_parallel_m_impl(
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
// x += tid.x * in_vec_size + simd_lid * values_per_thread;
// y += tid.x * out_vec_size + out_row;
x += simd_lid * values_per_thread;
y += out_row;
for (int k = 0; k < k_size; k += block_size) {
// U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device uint16_t* wb = (const device uint16_t*)wl;
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
for (int col = 0; col < m_size; col++) {
auto x_temp = x + col * k_size + simd_lid * values_per_thread + k;
auto x_temp = x + col * k_size;
U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread);
result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
if (bits == 4) {
result[col * results_per_simdgroup + row] += qdot_bit4(wb, x_thread, s, b, sum);
} else {
result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
}
}
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
// x += block_size;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
for (int col = 0; col < m_size; col++) {
result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]);
auto y_temp = y + col * n_size + out_row;
auto y_temp = y + col * n_size;
if (simd_lid == 0) {
y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]);
}