mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 07:58:41 +08:00
Improve perf
This commit is contained in:
parent
7df3a2887d
commit
c7cdd51f50
@ -225,6 +225,25 @@ inline U qdot(
|
|||||||
return scale * accum + sum * bias;
|
return scale * accum + sum * bias;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline float qdot_bit4(
|
||||||
|
const device uint16_t* w,
|
||||||
|
const thread float* x_thread,
|
||||||
|
float scale,
|
||||||
|
float bias,
|
||||||
|
float sum) {
|
||||||
|
|
||||||
|
float accum = 0;
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
accum +=
|
||||||
|
(x_thread[4 * i] * (w[i] & 0x000f) +
|
||||||
|
x_thread[4 * i + 1] * (w[i] & 0x00f0) +
|
||||||
|
x_thread[4 * i + 2] * (w[i] & 0x0f00) +
|
||||||
|
x_thread[4 * i + 3] * (w[i] & 0xf000));
|
||||||
|
}
|
||||||
|
|
||||||
|
return scale * accum + sum * bias;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename U, int values_per_thread, int bits>
|
template <typename U, int values_per_thread, int bits>
|
||||||
inline U qdot_safe(
|
inline U qdot_safe(
|
||||||
const device uint8_t* w,
|
const device uint8_t* w,
|
||||||
@ -728,35 +747,40 @@ METAL_FUNC void qmv_no_parallel_m_impl(
|
|||||||
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
||||||
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||||
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||||
// x += tid.x * in_vec_size + simd_lid * values_per_thread;
|
x += simd_lid * values_per_thread;
|
||||||
// y += tid.x * out_vec_size + out_row;
|
y += out_row;
|
||||||
|
|
||||||
for (int k = 0; k < k_size; k += block_size) {
|
for (int k = 0; k < k_size; k += block_size) {
|
||||||
// U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
// U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||||
|
const device uint16_t* wb = (const device uint16_t*)wl;
|
||||||
const device T* sl = scales + row * in_vec_size_g;
|
const device T* sl = scales + row * in_vec_size_g;
|
||||||
const device T* bl = biases + row * in_vec_size_g;
|
const device T* bl = biases + row * in_vec_size_g;
|
||||||
U s = sl[0];
|
U s = sl[0];
|
||||||
U b = bl[0];
|
U b = bl[0];
|
||||||
|
|
||||||
for (int col = 0; col < m_size; col++) {
|
for (int col = 0; col < m_size; col++) {
|
||||||
auto x_temp = x + col * k_size + simd_lid * values_per_thread + k;
|
auto x_temp = x + col * k_size;
|
||||||
U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread);
|
U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread);
|
||||||
result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
if (bits == 4) {
|
||||||
|
result[col * results_per_simdgroup + row] += qdot_bit4(wb, x_thread, s, b, sum);
|
||||||
|
} else {
|
||||||
|
result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ws += block_size * bytes_per_pack / pack_factor;
|
ws += block_size * bytes_per_pack / pack_factor;
|
||||||
scales += block_size / group_size;
|
scales += block_size / group_size;
|
||||||
biases += block_size / group_size;
|
biases += block_size / group_size;
|
||||||
// x += block_size;
|
x += block_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
for (int col = 0; col < m_size; col++) {
|
for (int col = 0; col < m_size; col++) {
|
||||||
result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]);
|
result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]);
|
||||||
auto y_temp = y + col * n_size + out_row;
|
auto y_temp = y + col * n_size;
|
||||||
if (simd_lid == 0) {
|
if (simd_lid == 0) {
|
||||||
y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]);
|
y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user