Mirror of https://github.com/ml-explore/mlx.git, synced 2025-08-28 00:36:32 +08:00

Improve perf

parent 7df3a2887d
commit c7cdd51f50
@@ -225,6 +225,25 @@ inline U qdot(
   return scale * accum + sum * bias;
 }
 
+inline float qdot_bit4(
+    const device uint16_t* w,
+    const thread float* x_thread,
+    float scale,
+    float bias,
+    float sum) {
+
+  float accum = 0;
+  for (int i = 0; i < 4; i++) {
+    accum +=
+        (x_thread[4 * i] * (w[i] & 0x000f) +
+         x_thread[4 * i + 1] * (w[i] & 0x00f0) +
+         x_thread[4 * i + 2] * (w[i] & 0x0f00) +
+         x_thread[4 * i + 3] * (w[i] & 0xf000));
+  }
+
+  return scale * accum + sum * bias;
+}
+
 template <typename U, int values_per_thread, int bits>
 inline U qdot_safe(
     const device uint8_t* w,
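Background (not part of the diff): qdot_bit4 multiplies the masked but unshifted nibbles, so it only yields the right dot product if the caller pre-scales the activations. In MLX's quantized kernels, the 4-bit load_vector path divides successive activation values by 16, 256 and 4096, which cancels the implicit factor of 16^j carried by each mask. A minimal plain-C++ sketch of that equivalence, using the hypothetical names dot16_reference and dot16_masked:

#include <cstdint>

// Reference form: extract each 4-bit weight with a shift before multiplying.
float dot16_reference(const uint16_t* w, const float* x) {
  float accum = 0.0f;
  for (int i = 0; i < 4; i++) {
    accum += x[4 * i + 0] * ((w[i] >> 0) & 0xf);
    accum += x[4 * i + 1] * ((w[i] >> 4) & 0xf);
    accum += x[4 * i + 2] * ((w[i] >> 8) & 0xf);
    accum += x[4 * i + 3] * ((w[i] >> 12) & 0xf);
  }
  return accum;
}

// Shift-free form, as in qdot_bit4: gives an identical result when
// x_scaled[4 * i + j] == x[4 * i + j] / 16^j, because
// (w[i] & (0xf << (4 * j))) == ((w[i] >> (4 * j)) & 0xf) * 16^j.
float dot16_masked(const uint16_t* w, const float* x_scaled) {
  float accum = 0.0f;
  for (int i = 0; i < 4; i++) {
    accum += x_scaled[4 * i + 0] * (w[i] & 0x000f);
    accum += x_scaled[4 * i + 1] * (w[i] & 0x00f0);
    accum += x_scaled[4 * i + 2] * (w[i] & 0x0f00);
    accum += x_scaled[4 * i + 3] * (w[i] & 0xf000);
  }
  return accum;
}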
@@ -728,35 +747,40 @@ METAL_FUNC void qmv_no_parallel_m_impl(
   ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
   scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   // x += tid.x * in_vec_size + simd_lid * values_per_thread;
   // y += tid.x * out_vec_size + out_row;
   x += simd_lid * values_per_thread;
   y += out_row;
 
   for (int k = 0; k < k_size; k += block_size) {
     // U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
     for (int row = 0; row < results_per_simdgroup; row++) {
       auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
+      const device uint16_t* wb = (const device uint16_t*)wl;
       const device T* sl = scales + row * in_vec_size_g;
       const device T* bl = biases + row * in_vec_size_g;
       U s = sl[0];
       U b = bl[0];
 
       for (int col = 0; col < m_size; col++) {
-        auto x_temp = x + col * k_size + simd_lid * values_per_thread + k;
+        auto x_temp = x + col * k_size;
         U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread);
-        result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+        if (bits == 4) {
+          result[col * results_per_simdgroup + row] += qdot_bit4(wb, x_thread, s, b, sum);
+        } else {
+          result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+        }
       }
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
     scales += block_size / group_size;
     biases += block_size / group_size;
-    // x += block_size;
+    x += block_size;
   }
 
   for (int row = 0; row < results_per_simdgroup; row++) {
     for (int col = 0; col < m_size; col++) {
       result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]);
-      auto y_temp = y + col * n_size + out_row;
+      auto y_temp = y + col * n_size;
       if (simd_lid == 0) {
         y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]);
       }
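Side note (not from the commit): the scale * accum + sum * bias epilogue shared by qdot and qdot_bit4 comes from factoring the affine dequantization out of the dot product. Each quantized weight dequantizes to scale * q + bias, so one group contributes scale * (sum of x*q) plus bias * (sum of x), which is presumably why load_vector also returns the running activation sum that the kernel then passes in as sum. A minimal scalar sketch, assuming one group of n values and a hypothetical helper name group_dot:

#include <cstdint>

// For one quantization group:
//   y = sum_j x[j] * (scale * q[j] + bias)
//     = scale * (sum_j x[j] * q[j]) + bias * (sum_j x[j])
// so the kernel only needs the two partial sums accumulated below.
float group_dot(const float* x, const uint8_t* q, int n, float scale, float bias) {
  float accum = 0.0f;  // sum_j x[j] * q[j]
  float sum = 0.0f;    // sum_j x[j]
  for (int j = 0; j < n; j++) {
    accum += x[j] * static_cast<float>(q[j]);
    sum += x[j];
  }
  return scale * accum + sum * bias;
}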