From 2c7fe67d56cec1cadf058d2d66c260afa4555ce4 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 27 Jul 2025 15:30:47 -0700
Subject: [PATCH] route more to gemv

---
 mlx/backend/cuda/gemms/gemv.cu | 78 ++++++++++++++++++++++------------
 1 file changed, 50 insertions(+), 28 deletions(-)

diff --git a/mlx/backend/cuda/gemms/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu
index b62d6e6b8..0297b6f73 100644
--- a/mlx/backend/cuda/gemms/gemv.cu
+++ b/mlx/backend/cuda/gemms/gemv.cu
@@ -11,7 +11,6 @@ namespace mlx::core::cu {
 
 namespace cg = cooperative_groups;
 
-static constexpr int n_per_thread = 4;
 static constexpr int rows_per_block = 8;
 
 template <typename T, int rows_per_block, int n_per_thread>
@@ -74,10 +73,27 @@ __global__ void gemv_batched(
 }
 
 bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
-  return K % (WARP_SIZE * n_per_thread) == 0 &&
+  bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
+  return is_multiple &&
       ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
 }
 
+template <typename F>
+void dispatch_n_per_thread(int n_per_thread, F&& f) {
+  switch (n_per_thread) {
+    case 2:
+      f(std::integral_constant<int, 2>{});
+      break;
+    case 4:
+      f(std::integral_constant<int, 4>{});
+      break;
+    case 8:
+      f(std::integral_constant<int, 8>{});
+      break;
+  }
+}
+
+
 void gemv(
     const array& a,
     const array& b,
@@ -114,33 +130,39 @@ void gemv(
       rows = M;
     }
     uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
-    if (batch_count == 1) {
-      auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
-      encoder.add_kernel_node(
-          kernel,
-          num_blocks_x,
-          block_dims,
-          mat,
-          vec,
-          out.data<DataType>(),
-          rows,
-          cols);
-    } else {
-      auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
-      encoder.add_kernel_node(
-          kernel,
-          dim3{num_blocks_x, batch_count},
-          block_dims,
-          mat,
-          vec,
-          out.data<DataType>(),
-          rows,
-          cols,
-          const_param(batch_shape),
-          mat_strides,
-          vec_strides,
-          batch_shape.size());
+    int n_per_t = 16 / sizeof(DataType);
+    while (K % (n_per_t * WARP_SIZE) != 0) {
+      n_per_t /= 2;
     }
+    dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
+      if (batch_count == 1) {
+        auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
+        encoder.add_kernel_node(
+            kernel,
+            num_blocks_x,
+            block_dims,
+            mat,
+            vec,
+            out.data<DataType>(),
+            rows,
+            cols);
+      } else {
+        auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
+        encoder.add_kernel_node(
+            kernel,
+            dim3{num_blocks_x, batch_count},
+            block_dims,
+            mat,
+            vec,
+            out.data<DataType>(),
+            rows,
+            cols,
+            const_param(batch_shape),
+            mat_strides,
+            vec_strides,
+            batch_shape.size());
+      }
+    });
   });
 }
 
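
For context, a minimal self-contained sketch (plain C++, no CUDA needed) of the per-thread work selection this patch introduces. dispatch_n_per_thread is taken from the diff above; the helper pick_n_per_thread and the driver in main are illustrative names, not part of the patch, and WARP_SIZE is assumed to be 32. Note that can_use_gemv gates this path on K being a multiple of 32; the sample inputs below are chosen so the halving loop lands on one of the dispatched cases (2, 4, or 8).

// sketch: how gemv picks n_per_thread and turns it into a compile-time constant
#include <cstdio>
#include <type_traits>

constexpr int WARP_SIZE = 32; // assumed warp width, as on current NVIDIA GPUs

// Copied from the patch: map a runtime value onto a compile-time constant
// so it can be used as a kernel template argument.
template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) {
  switch (n_per_thread) {
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 4:
      f(std::integral_constant<int, 4>{});
      break;
    case 8:
      f(std::integral_constant<int, 8>{});
      break;
  }
}

// Illustrative helper mirroring the loop in gemv(): start from 16 bytes of
// work per thread and halve until a warp pass evenly divides K.
int pick_n_per_thread(int K, int dtype_size) {
  int n_per_t = 16 / dtype_size;
  while (K % (n_per_t * WARP_SIZE) != 0) {
    n_per_t /= 2;
  }
  return n_per_t;
}

int main() {
  // float16 (2 bytes), K = 4096 -> 8 elements per thread
  // float32 (4 bytes), K = 4096 -> 4 elements per thread
  // float32 (4 bytes), K = 64   -> 2 elements per thread
  struct { int K, size; } cases[] = {{4096, 2}, {4096, 4}, {64, 4}};
  for (auto c : cases) {
    dispatch_n_per_thread(pick_n_per_thread(c.K, c.size), [&](auto n) {
      // n() is a constant expression here, usable as a template argument,
      // exactly as gemv_single<DataType, rows_per_block, n_per_thread()> above.
      std::printf("K=%d dtype_size=%d -> n_per_thread=%d\n", c.K, c.size, n());
    });
  }
  return 0;
}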