fix gemv regression (#2445)

This commit is contained in:
Awni Hannun
2025-07-30 14:23:01 -07:00
committed by GitHub
parent b405591249
commit d32519c8ee
3 changed files with 36 additions and 9 deletions

View File

@@ -43,10 +43,18 @@ struct alignas(sizeof(T) * N) AlignedVector {
};
template <int N, typename T>
inline __device__ bool is_aligned(T* x) {
inline __host__ __device__ bool is_aligned(T* x) {
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> unsafe_load_vector(
const T* ptr,
uint32_t offset) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
@@ -101,6 +109,13 @@ inline __device__ AlignedVector<T, N> load_vector(
}
}
template <int N, typename T>
inline __device__ void
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
}
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {

View File

@@ -27,8 +27,9 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
auto local_mat =
unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum +=
@@ -127,9 +128,13 @@ void gemv(
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
int n_per_t = 4;
while (K % (n_per_t * WARP_SIZE) != 0) {
n_per_t >>= 1;
int n_per_t;
if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
n_per_t = 4;
} else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
n_per_t = 2;
} else {
n_per_t = 1;
}
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) {