faster rms norm (#2433)

This commit is contained in:
Awni Hannun
2025-07-29 13:12:00 -07:00
committed by GitHub
parent 970dbe8e25
commit ef631d63af
11 changed files with 210 additions and 112 deletions

View File

@@ -31,8 +31,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum += static_cast<float>(local_mat.val[j]) *
static_cast<float>(local_vec.val[j]);
sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
}
}
@@ -73,8 +73,7 @@ __global__ void gemv_batched(
}
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
return is_multiple && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
}
template <typename F>