diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index d85d72d9e..e0325a1f1 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -121,8 +121,18 @@ struct GEMVKernel { for(int tm = 0; tm < TM; tm++) { // Load for the row - for(int tn = 0; tn < TN; tn++) { - inter[tn] = mat[tm * in_vec_size + bn + tn]; + if(bn + TN <= in_vec_size) { + #pragma clang loop unroll(full) + for(int tn = 0; tn < TN; tn++) { + inter[tn] = mat[tm * in_vec_size + bn + tn]; + } + + } else { // Edgecase + #pragma clang loop unroll(full) + for(int tn = 0; tn < TN; tn++) { + int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1); + inter[tn] = mat[tm * in_vec_size + col_idx]; + } } // Accumulate results