mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-26 15:58:14 +08:00
@@ -35,8 +35,8 @@ struct GEMVKernel {
|
||||
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
||||
|
||||
static_assert(
|
||||
SN == 8 || SN == 16 || SN == 32,
|
||||
"gemv block must have a width of 8, 16, or 32");
|
||||
SN == 4 || SN == 8 || SN == 16 || SN == 32,
|
||||
"gemv block must have a width of 4, 8, 16, or 32");
|
||||
|
||||
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
|
||||
// into blocks of (blockM, blockN) divided among threadgroups
|
||||
@@ -511,17 +511,21 @@ template <
|
||||
axpby)
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on
|
||||
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
||||
instantiate_gemv(name, itype, 4, 32, 4, 4) \
|
||||
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on
|
||||
instantiate_gemv(name, itype, 1, 8, 1, 32, 4, 4) \
|
||||
instantiate_gemv(name, itype, 1, 8, 1, 32, 1, 4) \
|
||||
instantiate_gemv(name, itype, 1, 1, 8, 4, 4, 4) \
|
||||
instantiate_gemv(name, itype, 1, 1, 8, 4, 1, 4) \
|
||||
instantiate_gemv(name, itype, 4, 1, 1, 32, 1, 4) \
|
||||
instantiate_gemv(name, itype, 4, 1, 1, 32, 4, 4) \
|
||||
instantiate_gemv(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
|
||||
|
||||
instantiate_gemv_blocks(float32, float);
|
||||
instantiate_gemv_blocks(float16, half);
|
||||
|
@@ -698,6 +698,15 @@ void gemv_axbpy(
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
sn = 32;
|
||||
|
||||
if (K <= 64) {
|
||||
bm = 1;
|
||||
sm = 8;
|
||||
sn = 4;
|
||||
} else if (K >= 16 * out_vector_len) {
|
||||
bm = 1;
|
||||
bn = 8;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
|
Reference in New Issue
Block a user