mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
tune quant dispatch (#2031)
This commit is contained in:
@@ -586,13 +586,13 @@ METAL_FUNC void qmv_quad_impl(
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||
const int in_vec_size_g = in_vec_size / group_size;
|
||||
- const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
+ const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
|
||||
|
||||
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
|
||||
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
||||
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
|
||||
- x += tid.y * in_vec_size + quad_lid * values_per_thread;
- y += tid.y * out_vec_size + out_row;
+ x += tid.x * in_vec_size + quad_lid * values_per_thread;
+ y += tid.x * out_vec_size + out_row;
|
||||
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user