New tuning for small K gemv (#2620)

* New tuning for small K gemv
This commit is contained in:
Jagrit Digani
2025-09-23 12:28:35 -07:00
committed by GitHub
parent fbbf3b9b3e
commit 7c7e48dbd1
2 changed files with 23 additions and 10 deletions

View File

@@ -35,8 +35,8 @@ struct GEMVKernel {
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
static_assert(
SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 8, 16, or 32");
SN == 4 || SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 4, 8, 16, or 32");
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
@@ -511,17 +511,21 @@ template <
axpby)
// clang-format off
// Six-parameter form of instantiate_gemv.
// Expands to four instantiate_gemv_helper invocations, one for each 0/1
// combination of the helper's last two arguments. bm, tm and tn are forwarded
// unchanged; the helper's 4th and 5th arguments are fixed to the literal 1,
// and this macro's `bn` is passed as the helper's 6th argument.
// NOTE(review): the meaning of the two trailing 0/1 flags is defined by
// instantiate_gemv_helper, which is not visible here - confirm there.
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on
// Eight-parameter form of instantiate_gemv.
// Expands to four instantiate_gemv_helper invocations, one for each 0/1
// combination of the helper's last two arguments. All of bm, bn, sm, sn, tm
// and tn are forwarded to the helper unchanged and in the same order.
// NOTE(review): the meaning of the two trailing 0/1 flags is defined by
// instantiate_gemv_helper, which is not visible here - confirm there.
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
// clang-format off
// Lists the instantiate_gemv configurations emitted for one (name, itype)
// pair; each row below is a single instantiate_gemv invocation.
// NOTE(review): this span contains rows with six arguments and rows with
// eight arguments, and a `// clang-format on` trailer appears in the middle
// of the list. Presumably the six-argument rows are the pre-change lines of
// the diff and the eight-argument rows are their replacements - confirm
// against the actual file before relying on this listing.
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on
instantiate_gemv(name, itype, 1, 8, 1, 32, 4, 4) \
instantiate_gemv(name, itype, 1, 8, 1, 32, 1, 4) \
instantiate_gemv(name, itype, 1, 1, 8, 4, 4, 4) \
instantiate_gemv(name, itype, 1, 1, 8, 4, 1, 4) \
instantiate_gemv(name, itype, 4, 1, 1, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 1, 1, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
// Emit the configurations above once with (float32, float) and once with
// (float16, half) as the (name, itype) pair.
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);

View File

@@ -698,6 +698,15 @@ void gemv_axbpy(
bm = out_vector_len >= 4096 ? 8 : 4;
sn = 32;
if (K <= 64) {
bm = 1;
sm = 8;
sn = 4;
} else if (K >= 16 * out_vector_len) {
bm = 1;
bn = 8;
}
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;