mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix oob reads in gemv kernel (#523)
This commit is contained in:
parent
ecb174ca9d
commit
6d3bee3364
@ -121,8 +121,18 @@ struct GEMVKernel {
|
|||||||
for(int tm = 0; tm < TM; tm++) {
|
for(int tm = 0; tm < TM; tm++) {
|
||||||
|
|
||||||
// Load for the row
|
// Load for the row
|
||||||
for(int tn = 0; tn < TN; tn++) {
|
if(bn + TN <= in_vec_size) {
|
||||||
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||||
|
}
|
||||||
|
|
||||||
|
} else { // Edgecase
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
|
for(int tn = 0; tn < TN; tn++) {
|
||||||
|
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||||
|
inter[tn] = mat[tm * in_vec_size + col_idx];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate results
|
// Accumulate results
|
||||||
|
Loading…
Reference in New Issue
Block a user