mirror of https://github.com/ml-explore/mlx.git
[CUDA] More sizes for gemv (#2429)
* route more to gemv
* route more sizes to custom gemv
@@ -11,7 +11,6 @@ namespace mlx::core::cu {
 
 namespace cg = cooperative_groups;
 
-static constexpr int n_per_thread = 4;
 static constexpr int rows_per_block = 8;
 
 template <typename T, int rows_per_block, int n_per_thread>
@@ -74,8 +73,23 @@ __global__ void gemv_batched(
 }
 
 bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
-  return K % (WARP_SIZE * n_per_thread) == 0 &&
-      ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
+  bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
+  return is_multiple && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
+}
+
+template <typename F>
+void dispatch_n_per_thread(int n_per_thread, F&& f) {
+  switch (n_per_thread) {
+    case 1:
+      f(std::integral_constant<int, 1>{});
+      break;
+    case 2:
+      f(std::integral_constant<int, 2>{});
+      break;
+    case 4:
+      f(std::integral_constant<int, 4>{});
+      break;
+  }
 }
 
 void gemv(
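The dispatch_n_per_thread helper added in this hunk maps a runtime n_per_thread value onto a compile-time constant by wrapping it in std::integral_constant and passing that tag to a generic callable. A minimal standalone sketch of the pattern; dispatch_n, fake_kernel, and the printf output are illustrative names, not part of the MLX sources:

// Sketch of the runtime-to-compile-time dispatch used by dispatch_n_per_thread.
// Hypothetical names; not the MLX implementation.
#include <cstdio>
#include <type_traits>

template <typename F>
void dispatch_n(int n, F&& f) {
  switch (n) {  // pick a compile-time tag based on a runtime value
    case 1: f(std::integral_constant<int, 1>{}); break;
    case 2: f(std::integral_constant<int, 2>{}); break;
    case 4: f(std::integral_constant<int, 4>{}); break;
  }
}

// Stand-in for gemv_single / gemv_batched: a template that needs the
// per-thread width as a compile-time constant.
template <int NPerThread>
void fake_kernel() {
  std::printf("instantiated with n_per_thread = %d\n", NPerThread);
}

int main() {
  int n_runtime = 4;  // e.g. derived from K at run time
  dispatch_n(n_runtime, [&](auto n) {
    // n is std::integral_constant<int, N>; n() is a constexpr call returning N,
    // so it can be used as a template argument.
    fake_kernel<n()>();
  });
}

Each case hands the callable a distinct integral_constant type, so the compiler generates one kernel specialization per supported width and the switch selects among them at run time.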
@@ -114,8 +128,13 @@ void gemv(
     rows = M;
   }
   uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
+  int n_per_t = 4;
+  while (K % (n_per_t * WARP_SIZE) != 0) {
+    n_per_t >>= 1;
+  }
+  dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
   if (batch_count == 1) {
-    auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
+    auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
     encoder.add_kernel_node(
         kernel,
         num_blocks_x,
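The loop added here picks the largest per-thread width in {4, 2, 1} whose warp-wide footprint of n_per_t * WARP_SIZE elements divides K exactly. The relaxed can_use_gemv check (K a multiple of 32) is what guarantees the loop stops with n_per_t >= 1 when WARP_SIZE is 32. A standalone sketch of that selection; pick_n_per_thread is a hypothetical name used only for illustration:

// Sketch of the n_per_t selection above, assuming WARP_SIZE == 32.
#include <cassert>

constexpr int WARP_SIZE = 32;

// Largest n in {4, 2, 1} such that n * WARP_SIZE divides K. The caller is
// expected to have checked K % 32 == 0 (as can_use_gemv does), so the loop
// always stops before n reaches 0.
int pick_n_per_thread(int K) {
  int n = 4;
  while (K % (n * WARP_SIZE) != 0) {
    n >>= 1;
  }
  return n;
}

int main() {
  assert(pick_n_per_thread(128) == 4);  // 128 % 128 == 0
  assert(pick_n_per_thread(64) == 2);   // 64 % 128 != 0, 64 % 64 == 0
  assert(pick_n_per_thread(96) == 1);   // 96 % 64 != 0, 96 % 32 == 0
  return 0;
}

Sizes such as K = 64 or K = 96 previously failed the old K % (WARP_SIZE * n_per_thread) == 0 test with n_per_thread fixed at 4, so they could not take the custom gemv path; with the per-call selection they now run the kernel at a narrower per-thread width.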
@@ -126,7 +145,7 @@ void gemv(
         rows,
         cols);
   } else {
-    auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
+    auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
     encoder.add_kernel_node(
         kernel,
         dim3{num_blocks_x, batch_count},
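Both kernel-selection sites switch from n_per_thread to n_per_thread(). The first hunk removed the file-scope static constexpr int n_per_thread, and inside the dispatch lambda the same name now binds to a std::integral_constant<int, N> object rather than an int, so it has to be invoked (or spelled decltype(...)::value) to produce the compile-time value. A tiny compile-time check of that property; NeedsConstant is a hypothetical placeholder:

// Sketch: an integral_constant object can still supply a template argument.
#include <type_traits>

template <int N>
struct NeedsConstant {};  // stand-in for gemv_single / gemv_batched

int main() {
  auto n_per_thread = std::integral_constant<int, 4>{};
  static_assert(n_per_thread() == 4, "constexpr operator() yields the value");
  static_assert(decltype(n_per_thread)::value == 4, "equivalent spelling");
  NeedsConstant<n_per_thread()> chosen{};  // valid template argument
  (void)chosen;
  return 0;
}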
@@ -142,6 +161,7 @@ void gemv(
         batch_shape.size());
   }
   });
+  });
 }
 
 } // namespace mlx::core::cu