Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Faster small batch qmv (#1861)

* faster small batch qmv
* swap batch and block dims for qvm and qmv regular
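Both parts of the commit make the same mechanical change across the hunks below: the block of output rows (out_row, or out_col for qvm) is now derived from tid.y instead of tid.x, while the batch element is selected with tid.x instead of tid.y. A minimal sketch of that convention follows; the kernel name, buffer layout, and types are illustrative only and are not MLX's actual signature, and tid is assumed to be the threadgroup position in the dispatch grid, as in the real kernels.

#include <metal_stdlib>
using namespace metal;

// Hypothetical sketch, not the MLX kernel: it only illustrates the assumed
// binding of `tid` and the roles its components play after this commit.
kernel void qmv_sketch(
    const device half* x [[buffer(0)]],
    device half* y [[buffer(1)]],
    constant int& in_vec_size [[buffer(2)]],
    constant int& out_vec_size [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  // After this commit: tid.x selects the batch element, tid.y the block of
  // output rows (the real out_row expression also folds in simd_gid).
  const int out_row = tid.y; // stand-in for the full out_row computation below
  x += tid.x * in_vec_size;  // step to this batch element's input vector
  y += tid.x * out_vec_size + out_row; // step to this batch element's output block
  y[0] = x[0]; // placeholder body; the real kernels accumulate the quantized matvec here
}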
@@ -648,14 +648,14 @@ METAL_FUNC void qmv_fast_impl(
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
-  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;

   ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
   scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-  x += tid.y * in_vec_size + simd_lid * values_per_thread;
-  y += tid.y * out_vec_size + out_row;
+  x += tid.x * in_vec_size + simd_lid * values_per_thread;
+  y += tid.x * out_vec_size + out_row;

   for (int k = 0; k < in_vec_size; k += block_size) {
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
@@ -716,7 +716,7 @@ METAL_FUNC void qmv_impl(
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
-  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
   const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

@@ -731,8 +731,8 @@ METAL_FUNC void qmv_impl(
         out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
     scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-    x += tid.y * in_vec_size + simd_lid * values_per_thread;
-    y += tid.y * out_vec_size + out_row;
+    x += tid.x * in_vec_size + simd_lid * values_per_thread;
+    y += tid.x * out_vec_size + out_row;

     int k = 0;
     for (; k < in_vec_size - block_size; k += block_size) {
@@ -788,8 +788,8 @@ METAL_FUNC void qmv_impl(
         simd_lid * packs_per_thread * bytes_per_pack;
     scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-    x += tid.y * in_vec_size + simd_lid * values_per_thread;
-    y += tid.y * out_vec_size + used_out_row;
+    x += tid.x * in_vec_size + simd_lid * values_per_thread;
+    y += tid.x * out_vec_size + used_out_row;

     int k = 0;
     for (; k < in_vec_size - block_size; k += block_size) {
@@ -876,12 +876,12 @@ METAL_FUNC void qvm_impl(
   // Adjust positions
   const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid);
+  int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
   ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
   scales += out_col / group_size + simd_lid * out_vec_size_g;
   biases += out_col / group_size + simd_lid * out_vec_size_g;
-  x += tid.y * in_vec_size + simd_lid;
-  y += tid.y * out_vec_size + out_col;
+  x += tid.x * in_vec_size + simd_lid;
+  y += tid.x * out_vec_size + out_col;

   if (out_col >= out_vec_size) {
     return;
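Taken together, the hunks converge on a single indexing convention, summarized in the hedged helper below; the function name and signature are illustrative and not part of MLX.

// Illustrative only (not MLX code): after this commit every kernel above
// offsets its input by the batch index in tid.x and its output by that same
// batch offset plus an output block derived from tid.y.
METAL_FUNC int2 batch_and_block_offsets(
    uint3 tid,
    int in_vec_size,
    int out_vec_size,
    int out_row) { // out_row/out_col comes from tid.y (and simd_gid) as above
  const int x_offset = tid.x * in_vec_size;            // batch offset into x
  const int y_offset = tid.x * out_vec_size + out_row; // batch offset plus output block in y
  return int2(x_offset, y_offset);
}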