Faster small batch qmv (#1861)

* faster small batch qmv

* swap batch and block dims for qvm and qmv regular
This commit is contained in:
Awni Hannun 2025-02-12 22:02:36 -08:00 committed by GitHub
parent d274ae77f2
commit e425dc00c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 15 deletions

View File

@ -648,14 +648,14 @@ METAL_FUNC void qmv_fast_impl(
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size; const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup; simd_gid * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread; x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row; y += tid.x * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) { for (int k = 0; k < in_vec_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
@ -716,7 +716,7 @@ METAL_FUNC void qmv_impl(
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size; const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup; simd_gid * results_per_simdgroup;
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
@ -731,8 +731,8 @@ METAL_FUNC void qmv_impl(
out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread; x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row; y += tid.x * out_vec_size + out_row;
int k = 0; int k = 0;
for (; k < in_vec_size - block_size; k += block_size) { for (; k < in_vec_size - block_size; k += block_size) {
@ -788,8 +788,8 @@ METAL_FUNC void qmv_impl(
simd_lid * packs_per_thread * bytes_per_pack; simd_lid * packs_per_thread * bytes_per_pack;
scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread; x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + used_out_row; y += tid.x * out_vec_size + used_out_row;
int k = 0; int k = 0;
for (; k < in_vec_size - block_size; k += block_size) { for (; k < in_vec_size - block_size; k += block_size) {
@ -876,12 +876,12 @@ METAL_FUNC void qvm_impl(
// Adjust positions // Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
const int out_vec_size_g = out_vec_size / group_size; const int out_vec_size_g = out_vec_size / group_size;
int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid); int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
scales += out_col / group_size + simd_lid * out_vec_size_g; scales += out_col / group_size + simd_lid * out_vec_size_g;
biases += out_col / group_size + simd_lid * out_vec_size_g; biases += out_col / group_size + simd_lid * out_vec_size_g;
x += tid.y * in_vec_size + simd_lid; x += tid.x * in_vec_size + simd_lid;
y += tid.y * out_vec_size + out_col; y += tid.x * out_vec_size + out_col;
if (out_col >= out_vec_size) { if (out_col >= out_vec_size) {
return; return;

View File

@ -167,7 +167,7 @@ void qvm_split_k(
int bo = 64; int bo = 64;
int bd = 32; int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1); MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N); MTL::Size grid_dims = MTL::Size(B, O / bo, N);
auto& x_pre = inputs[0]; auto& x_pre = inputs[0];
auto& w_pre = inputs[1]; auto& w_pre = inputs[1];
@ -313,13 +313,13 @@ void qmm_op(
int bo = 8; int bo = 8;
int bd = 32; int bd = 32;
group_dims = MTL::Size(bd, 2, 1); group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N); grid_dims = MTL::Size(B, O / bo, N);
} else if (B < 6) { } else if (B < 6) {
name += "qmv"; name += "qmv";
int bo = 8; int bo = 8;
int bd = 32; int bd = 32;
group_dims = MTL::Size(bd, 2, 1); group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size((O + bo - 1) / bo, B, N); grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
} else { } else {
int wn = 2; int wn = 2;
int wm = 2; int wm = 2;
@ -339,7 +339,7 @@ void qmm_op(
int bo = 64; int bo = 64;
int bd = 32; int bd = 32;
group_dims = MTL::Size(bd, 2, 1); group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N); grid_dims = MTL::Size(B, O / bo, N);
} else { } else {
name += "qmm_n"; name += "qmm_n";
int wn = 2; int wn = 2;