diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index baaf84f2d..b04545152 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -35,8 +35,8 @@ struct GEMVKernel { static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); static_assert( - SN == 8 || SN == 16 || SN == 32, - "gemv block must have a width of 8, 16, or 32"); + SN == 4 || SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 4, 8, 16, or 32"); // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up // into blocks of (blockM, blockN) divided among threadgroups @@ -511,17 +511,21 @@ template < axpby) // clang-format off -#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ - instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \ - instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \ - instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \ - instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on +#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \ + instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \ + instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \ + instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \ + instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on // clang-format off #define instantiate_gemv_blocks(name, itype) \ - instantiate_gemv(name, itype, 4, 32, 1, 4) \ - instantiate_gemv(name, itype, 4, 32, 4, 4) \ - instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on + instantiate_gemv(name, itype, 1, 8, 1, 32, 4, 4) \ + instantiate_gemv(name, itype, 1, 8, 1, 32, 1, 4) \ + instantiate_gemv(name, itype, 1, 1, 8, 4, 4, 4) \ + instantiate_gemv(name, itype, 1, 1, 8, 4, 1, 4) \ + instantiate_gemv(name, itype, 4, 1, 1, 32, 1, 4) \ + instantiate_gemv(name, itype, 4, 1, 1, 32, 4, 4) \ + instantiate_gemv(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on instantiate_gemv_blocks(float32, float); instantiate_gemv_blocks(float16, half); diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 5a185f416..1ceae3f33 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -698,6 +698,15 @@ void gemv_axbpy( bm = out_vector_len >= 4096 ? 8 : 4; sn = 32; + if (K <= 64) { + bm = 1; + sm = 8; + sn = 4; + } else if (K >= 16 * out_vector_len) { + bm = 1; + bn = 8; + } + // Specialized kernel for very small outputs tm = out_vector_len < tm ? 1 : tm;