mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-26 15:58:14 +08:00
@@ -35,8 +35,8 @@ struct GEMVKernel {
|
|||||||
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
SN == 8 || SN == 16 || SN == 32,
|
SN == 4 || SN == 8 || SN == 16 || SN == 32,
|
||||||
"gemv block must have a width of 8, 16, or 32");
|
"gemv block must have a width of 4, 8, 16, or 32");
|
||||||
|
|
||||||
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
|
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
|
||||||
// into blocks of (blockM, blockN) divided among threadgroups
|
// into blocks of (blockM, blockN) divided among threadgroups
|
||||||
@@ -511,17 +511,21 @@ template <
|
|||||||
axpby)
|
axpby)
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
|
||||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \
|
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
|
||||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \
|
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
|
||||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \
|
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
|
||||||
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on
|
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#define instantiate_gemv_blocks(name, itype) \
|
#define instantiate_gemv_blocks(name, itype) \
|
||||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
instantiate_gemv(name, itype, 1, 8, 1, 32, 4, 4) \
|
||||||
instantiate_gemv(name, itype, 4, 32, 4, 4) \
|
instantiate_gemv(name, itype, 1, 8, 1, 32, 1, 4) \
|
||||||
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on
|
instantiate_gemv(name, itype, 1, 1, 8, 4, 4, 4) \
|
||||||
|
instantiate_gemv(name, itype, 1, 1, 8, 4, 1, 4) \
|
||||||
|
instantiate_gemv(name, itype, 4, 1, 1, 32, 1, 4) \
|
||||||
|
instantiate_gemv(name, itype, 4, 1, 1, 32, 4, 4) \
|
||||||
|
instantiate_gemv(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
|
||||||
|
|
||||||
instantiate_gemv_blocks(float32, float);
|
instantiate_gemv_blocks(float32, float);
|
||||||
instantiate_gemv_blocks(float16, half);
|
instantiate_gemv_blocks(float16, half);
|
||||||
|
@@ -698,6 +698,15 @@ void gemv_axbpy(
|
|||||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||||
sn = 32;
|
sn = 32;
|
||||||
|
|
||||||
|
if (K <= 64) {
|
||||||
|
bm = 1;
|
||||||
|
sm = 8;
|
||||||
|
sn = 4;
|
||||||
|
} else if (K >= 16 * out_vector_len) {
|
||||||
|
bm = 1;
|
||||||
|
bn = 8;
|
||||||
|
}
|
||||||
|
|
||||||
// Specialized kernel for very small outputs
|
// Specialized kernel for very small outputs
|
||||||
tm = out_vector_len < tm ? 1 : tm;
|
tm = out_vector_len < tm ? 1 : tm;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user