tune quant dispatch (#2031)

This commit is contained in:
Awni Hannun
2025-04-02 20:05:54 -07:00
committed by GitHub
parent c23888acd7
commit 9ba81e3da4
2 changed files with 53 additions and 7 deletions

View File

@@ -586,13 +586,13 @@ METAL_FUNC void qmv_quad_impl(
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.y * in_vec_size + quad_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
x += tid.x * in_vec_size + quad_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);