route more to gemv

Awni Hannun
2025-07-27 15:30:47 -07:00
parent 1588659062
commit 2c7fe67d56


@@ -11,7 +11,6 @@ namespace mlx::core::cu {
 
 namespace cg = cooperative_groups;
 
-static constexpr int n_per_thread = 4;
 static constexpr int rows_per_block = 8;
 
 template <typename T, int rows_per_block, int n_per_thread>
@@ -74,10 +73,27 @@ __global__ void gemv_batched(
 }
 
 bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
-  return K % (WARP_SIZE * n_per_thread) == 0 &&
+  bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
+  return is_multiple &&
       ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
 }
 
+template <typename F>
+void dispatch_n_per_thread(int n_per_thread, F&& f) {
+  switch (n_per_thread) {
+    case 2:
+      f(std::integral_constant<int, 2>{});
+      break;
+    case 4:
+      f(std::integral_constant<int, 4>{});
+      break;
+    case 8:
+      f(std::integral_constant<int, 8>{});
+      break;
+  }
+}
+
 void gemv(
     const array& a,
     const array& b,
@@ -114,8 +130,13 @@ void gemv(
       rows = M;
     }
     uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
+    int n_per_t = 16 / sizeof(DataType);
+    while (K % (n_per_t * WARP_SIZE) != 0) {
+      n_per_t /= 2;
+    }
+    dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
     if (batch_count == 1) {
-      auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
+      auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
       encoder.add_kernel_node(
           kernel,
           num_blocks_x,
@@ -126,7 +147,7 @@ void gemv(
           rows,
           cols);
     } else {
-      auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
+      auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
       encoder.add_kernel_node(
           kernel,
           dim3{num_blocks_x, batch_count},
@@ -142,6 +163,7 @@ void gemv(
           batch_shape.size());
     }
   });
+  });
 }
 
 } // namespace mlx::core::cu
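
The net effect is that gemv no longer requires K to be a multiple of WARP_SIZE * 4: the number of elements each thread reads is now chosen per dtype at runtime (starting from 16 bytes per thread and halving until it divides K), then routed to a compile-time template parameter through std::integral_constant. Below is a minimal standalone sketch of that selection-plus-dispatch pattern; pick_n_per_thread, launch_stub, the WARP_SIZE constant, and the sample sizes in main are illustrative assumptions, not MLX code.

// Minimal sketch of the pattern in the diff above; pick_n_per_thread,
// launch_stub, and the sample sizes are illustrative, not part of MLX.
#include <cstdio>
#include <type_traits>
#include <utility>

constexpr int WARP_SIZE = 32; // assumption: NVIDIA warp width

// Runtime int -> compile-time constant, mirroring dispatch_n_per_thread in
// the diff. Values outside {2, 4, 8} fall through and nothing is invoked.
template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) {
  switch (n_per_thread) {
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 4:
      f(std::integral_constant<int, 4>{});
      break;
    case 8:
      f(std::integral_constant<int, 8>{});
      break;
  }
}

// Start at 16 bytes per thread (8 half/bfloat16 elements, 4 floats) and halve
// until one warp-wide step covers a divisor of K, as the new code in gemv()
// does. For a K that is only a multiple of 32 this can reach 1, which the
// dispatcher above ignores; in the diff, can_use_gemv gates which shapes
// reach this path in the first place.
int pick_n_per_thread(int K, int dtype_size) {
  int n_per_t = 16 / dtype_size;
  while (K % (n_per_t * WARP_SIZE) != 0) {
    n_per_t /= 2;
  }
  return n_per_t;
}

// Stand-in for instantiating gemv_single/gemv_batched with a compile-time
// n_per_thread; it only reports the value it was instantiated with.
template <int n_per_thread>
void launch_stub(int K) {
  std::printf("K=%d -> n_per_thread=%d (%d elements per warp step)\n",
              K, n_per_thread, n_per_thread * WARP_SIZE);
}

int main() {
  // {K, sizeof(dtype)}:
  //   4096 with a 4-byte dtype keeps n_per_thread = 4 (4096 % 128 == 0)
  //   4096 with a 2-byte dtype keeps n_per_thread = 8 (4096 % 256 == 0)
  //   192 with a 2-byte dtype drops to n_per_thread = 2 (192 % 64 == 0 only)
  const std::pair<int, int> cases[] = {{4096, 4}, {4096, 2}, {192, 2}};
  for (const auto& c : cases) {
    const int K = c.first;           // reduction dimension
    const int dtype_size = c.second; // bytes per element
    dispatch_n_per_thread(pick_n_per_thread(K, dtype_size), [&](auto n_per_thread) {
      // n_per_thread() is a constant expression, so it can parameterize a
      // template, just like gemv_single<DataType, rows_per_block, n_per_thread()>.
      launch_stub<n_per_thread()>(K);
    });
  }
  return 0;
}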