Fix oob reads in gemv kernel (#523)

This commit is contained in:
Jagrit Digani 2024-01-22 12:06:04 -08:00 committed by GitHub
parent ecb174ca9d
commit 6d3bee3364
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -121,8 +121,18 @@ struct GEMVKernel {
for(int tm = 0; tm < TM; tm++) {
// Load for the row
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
inter[tn] = mat[tm * in_vec_size + col_idx];
}
}
// Accumulate results