diff --git a/mlx/backend/metal/kernels/gemv_masked.h b/mlx/backend/metal/kernels/gemv_masked.h
index 75bc7354c..96b0c2821 100644
--- a/mlx/backend/metal/kernels/gemv_masked.h
+++ b/mlx/backend/metal/kernels/gemv_masked.h
@@ -262,36 +262,37 @@ struct GEMVKernel {
       vec_mask_offset += vec_mask_step;
     }
 
-    if (leftover > 0 &&
-        (!has_operand_mask ||
-         (bool(mat_mask[mat_mask_offset]) &&
-          bool(vec_mask[vec_mask_offset])))) {
-      T block_scale{1};
-      if (has_mul_operand_mask) {
-        block_scale =
-            T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
-      }
-
-      load_safe(in_vec, v_coeff, bn, in_size);
-
-      // Apply scale
-      if (has_mul_operand_mask) {
-        MLX_MTL_PRAGMA_UNROLL
-        for (int tn = 0; tn < TN; tn++) {
-          v_coeff[tn] *= block_scale;
+    if (leftover > 0) {
+      if (!has_operand_mask ||
+          (bool(mat_mask[mat_mask_offset]) &&
+           bool(vec_mask[vec_mask_offset]))) {
+        T block_scale{1};
+        if (has_mul_operand_mask) {
+          block_scale =
+              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
         }
-      }
 
-      // Per thread work loop
-      MLX_MTL_PRAGMA_UNROLL
-      for (int tm = 0; tm < TM; tm++) {
-        // Load for the row
-        load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
+        load_safe(in_vec, v_coeff, bn, in_size);
 
-        // Accumulate results
+        // Apply scale
+        if (has_mul_operand_mask) {
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tn = 0; tn < TN; tn++) {
+            v_coeff[tn] *= block_scale;
+          }
+        }
+
+        // Per thread work loop
         MLX_MTL_PRAGMA_UNROLL
-        for (int tn = 0; tn < TN; tn++) {
-          result[tm] += inter[tn] * v_coeff[tn];
+        for (int tm = 0; tm < TM; tm++) {
+          // Load for the row
+          load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
+
+          // Accumulate results
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tn = 0; tn < TN; tn++) {
+            result[tm] += inter[tn] * v_coeff[tn];
+          }
         }
       }
     }
@@ -544,31 +545,32 @@ struct GEMVTKernel {
       vec_mask_offset += vec_mask_step;
     }
 
-    if (leftover > 0 &&
-        (!has_operand_mask ||
-         (bool(mat_mask[mat_mask_offset]) &&
-          bool(vec_mask[vec_mask_offset])))) {
-      T block_scale{1};
-      if (has_mul_operand_mask) {
-        block_scale =
-            T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
-      }
-
-      for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
-        v_coeff[tm] = static_cast<T>(in_vec[bm + tm]);
-
+    if (leftover > 0) {
+      if (!has_operand_mask ||
+          (bool(mat_mask[mat_mask_offset]) &&
+           bool(vec_mask[vec_mask_offset]))) {
+        T block_scale{1};
         if (has_mul_operand_mask) {
-          v_coeff[tm] *= block_scale;
+          block_scale =
+              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
         }
 
-        MLX_MTL_PRAGMA_UNROLL
-        for (int tn = 0; tn < TN; tn++) {
-          inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
-        }
+        for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
+          v_coeff[tm] = static_cast<T>(in_vec[bm + tm]);
 
-        MLX_MTL_PRAGMA_UNROLL
-        for (int tn = 0; tn < TN; tn++) {
-          result[tn] += v_coeff[tm] * inter[tn];
+          if (has_mul_operand_mask) {
+            v_coeff[tm] *= block_scale;
+          }
+
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tn = 0; tn < TN; tn++) {
+            inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
+          }
+
+          MLX_MTL_PRAGMA_UNROLL
+          for (int tn = 0; tn < TN; tn++) {
+            result[tn] += v_coeff[tm] * inter[tn];
+          }
        }
      }
    }
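
For readers outside the Metal source: the leftover path restructured above handles the final partial block of columns (GEMVKernel) or rows (GEMVTKernel). With boolean operand masks the whole block is skipped when either mask value is zero; with multiplicative masks the block's contribution is scaled by the product of the two mask values instead. The restructuring from one compound condition into nested ifs leaves this behavior unchanged; it only separates the bounds check (leftover > 0) from the mask check. Below is a minimal CPU-side C++ sketch of that semantics for a single output row, under stated assumptions: the function name, parameters, and std::vector types are illustrative, not MLX API.

    // CPU-side sketch of the masked leftover-block semantics; the names
    // gemv_leftover_block, mat_row, etc. are hypothetical, not from MLX.
    #include <cstddef>
    #include <vector>

    // Accumulate the final partial block [bn, in_size) of a masked GEMV row.
    // - has_operand_mask: skip the block when either mask value is zero.
    // - has_mul_operand_mask: scale the block by the product of mask values.
    template <typename T>
    void gemv_leftover_block(
        const std::vector<T>& mat_row, // one matrix row, length in_size
        const std::vector<T>& in_vec,  // input vector, length in_size
        size_t bn,                     // first column of the leftover block
        size_t in_size,                // total columns; leftover = in_size - bn
        T mat_mask_val,                // mask value covering this block of mat
        T vec_mask_val,                // mask value covering this block of vec
        bool has_operand_mask,
        bool has_mul_operand_mask,
        T& result) {
      if (in_size <= bn) {
        return; // no leftover block
      }
      // Boolean masks: skip the block when either operand is masked out.
      if (has_operand_mask && !(bool(mat_mask_val) && bool(vec_mask_val))) {
        return;
      }
      // Multiplicative masks: scale the block's contribution.
      T block_scale{1};
      if (has_mul_operand_mask) {
        block_scale = mat_mask_val * vec_mask_val;
      }
      for (size_t n = bn; n < in_size; n++) {
        result += mat_row[n] * (in_vec[n] * block_scale);
      }
    }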