mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
route more sizes to custom gemv
This commit is contained in:
@@ -74,26 +74,24 @@ __global__ void gemv_batched(
|
||||
|
||||
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
|
||||
bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
|
||||
return is_multiple &&
|
||||
((M == 1 && b_transposed) || (N == 1 && !a_transposed));
|
||||
return is_multiple && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void dispatch_n_per_thread(int n_per_thread, F&& f) {
|
||||
switch (n_per_thread) {
|
||||
case 1:
|
||||
f(std::integral_constant<int, 1>{});
|
||||
break;
|
||||
case 2:
|
||||
f(std::integral_constant<int, 2>{});
|
||||
break;
|
||||
case 4:
|
||||
f(std::integral_constant<int, 4>{});
|
||||
break;
|
||||
case 8:
|
||||
f(std::integral_constant<int, 8>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void gemv(
|
||||
const array& a,
|
||||
const array& b,
|
||||
@@ -130,9 +128,9 @@ void gemv(
|
||||
rows = M;
|
||||
}
|
||||
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
|
||||
int n_per_t = 16 / sizeof(DataType);
|
||||
int n_per_t = 4;
|
||||
while (K % (n_per_t * WARP_SIZE) != 0) {
|
||||
n_per_t /= 2;
|
||||
n_per_t >>= 1;
|
||||
}
|
||||
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
|
||||
if (batch_count == 1) {
|
||||
|
||||
Reference in New Issue
Block a user