From 6d3bee336443102c6960b1ce9723dc78dcf38a2b Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Mon, 22 Jan 2024 12:06:04 -0800 Subject: [PATCH] Fix oob reads in gemv kernel (#523) --- mlx/backend/metal/kernels/gemv.metal | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index d85d72d9e..e0325a1f1 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -121,8 +121,18 @@ struct GEMVKernel { for(int tm = 0; tm < TM; tm++) { // Load for the row - for(int tn = 0; tn < TN; tn++) { - inter[tn] = mat[tm * in_vec_size + bn + tn]; + if(bn + TN <= in_vec_size) { + #pragma clang loop unroll(full) + for(int tn = 0; tn < TN; tn++) { + inter[tn] = mat[tm * in_vec_size + bn + tn]; + } + + } else { // Edgecase + #pragma clang loop unroll(full) + for(int tn = 0; tn < TN; tn++) { + int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1); + inter[tn] = mat[tm * in_vec_size + col_idx]; + } } // Accumulate results