Improve perf

This commit is contained in:
TianyiZhao1437 2025-07-23 10:12:38 +08:00
parent 7df3a2887d
commit c7cdd51f50

View File

@ -225,6 +225,25 @@ inline U qdot(
return scale * accum + sum * bias; return scale * accum + sum * bias;
} }
inline float qdot_bit4(
const device uint16_t* w,
const thread float* x_thread,
float scale,
float bias,
float sum) {
float accum = 0;
for (int i = 0; i < 4; i++) {
accum +=
(x_thread[4 * i] * (w[i] & 0x000f) +
x_thread[4 * i + 1] * (w[i] & 0x00f0) +
x_thread[4 * i + 2] * (w[i] & 0x0f00) +
x_thread[4 * i + 3] * (w[i] & 0xf000));
}
return scale * accum + sum * bias;
}
template <typename U, int values_per_thread, int bits> template <typename U, int values_per_thread, int bits>
inline U qdot_safe( inline U qdot_safe(
const device uint8_t* w, const device uint8_t* w,
@ -728,35 +747,40 @@ METAL_FUNC void qmv_no_parallel_m_impl(
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
// x += tid.x * in_vec_size + simd_lid * values_per_thread; x += simd_lid * values_per_thread;
// y += tid.x * out_vec_size + out_row; y += out_row;
for (int k = 0; k < k_size; k += block_size) { for (int k = 0; k < k_size; k += block_size) {
// U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); // U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device uint16_t* wb = (const device uint16_t*)wl;
const device T* sl = scales + row * in_vec_size_g; const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
for (int col = 0; col < m_size; col++) { for (int col = 0; col < m_size; col++) {
auto x_temp = x + col * k_size + simd_lid * values_per_thread + k; auto x_temp = x + col * k_size;
U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread);
result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); if (bits == 4) {
result[col * results_per_simdgroup + row] += qdot_bit4(wb, x_thread, s, b, sum);
} else {
result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
} }
} }
ws += block_size * bytes_per_pack / pack_factor; ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size; scales += block_size / group_size;
biases += block_size / group_size; biases += block_size / group_size;
// x += block_size; x += block_size;
} }
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
for (int col = 0; col < m_size; col++) { for (int col = 0; col < m_size; col++) {
result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]); result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]);
auto y_temp = y + col * n_size + out_row; auto y_temp = y + col * n_size;
if (simd_lid == 0) { if (simd_lid == 0) {
y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]); y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]);
} }