mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix oob reads in gemv kernel (#523)
This commit is contained in:
parent
ecb174ca9d
commit
6d3bee3364
@ -121,8 +121,18 @@ struct GEMVKernel {
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
|
||||
// Load for the row
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||
if(bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||
inter[tn] = mat[tm * in_vec_size + col_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate results
|
||||
|
Loading…
Reference in New Issue
Block a user