diff --git a/mlx/backend/cuda/gemms/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu index 0297b6f73..9aaaa2541 100644 --- a/mlx/backend/cuda/gemms/gemv.cu +++ b/mlx/backend/cuda/gemms/gemv.cu @@ -74,26 +74,24 @@ __global__ void gemv_batched( bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) { bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0; - return is_multiple && - ((M == 1 && b_transposed) || (N == 1 && !a_transposed)); + return is_multiple && ((M == 1 && b_transposed) || (N == 1 && !a_transposed)); } template void dispatch_n_per_thread(int n_per_thread, F&& f) { switch (n_per_thread) { + case 1: + f(std::integral_constant{}); + break; case 2: f(std::integral_constant{}); break; case 4: f(std::integral_constant{}); break; - case 8: - f(std::integral_constant{}); - break; } } - void gemv( const array& a, const array& b, @@ -130,9 +128,9 @@ void gemv( rows = M; } uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; - int n_per_t = 16 / sizeof(DataType); + int n_per_t = 4; while (K % (n_per_t * WARP_SIZE) != 0) { - n_per_t /= 2; + n_per_t >>= 1; } dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { if (batch_count == 1) {