route more sizes to custom gemv

This commit is contained in:
Awni Hannun
2025-07-27 15:43:49 -07:00
parent 2c7fe67d56
commit 04b54d99a4

View File

@@ -74,26 +74,24 @@ __global__ void gemv_batched(
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) { bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0; bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
return is_multiple && return is_multiple && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
((M == 1 && b_transposed) || (N == 1 && !a_transposed));
} }
template <typename F> template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) { void dispatch_n_per_thread(int n_per_thread, F&& f) {
switch (n_per_thread) { switch (n_per_thread) {
case 1:
f(std::integral_constant<int, 1>{});
break;
case 2: case 2:
f(std::integral_constant<int, 2>{}); f(std::integral_constant<int, 2>{});
break; break;
case 4: case 4:
f(std::integral_constant<int, 4>{}); f(std::integral_constant<int, 4>{});
break; break;
case 8:
f(std::integral_constant<int, 8>{});
break;
} }
} }
void gemv( void gemv(
const array& a, const array& a,
const array& b, const array& b,
@@ -130,9 +128,9 @@ void gemv(
rows = M; rows = M;
} }
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
int n_per_t = 16 / sizeof(DataType); int n_per_t = 4;
while (K % (n_per_t * WARP_SIZE) != 0) { while (K % (n_per_t * WARP_SIZE) != 0) {
n_per_t /= 2; n_per_t >>= 1;
} }
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) { if (batch_count == 1) {