Ensure no oob read in gemv_masked

This commit is contained in:
Angelos Katharopoulos
2025-08-17 00:40:32 -07:00
parent 1ba18ff7d9
commit 753081052e

View File

@@ -262,10 +262,10 @@ struct GEMVKernel {
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
if (leftover > 0) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
@@ -295,6 +295,7 @@ struct GEMVKernel {
}
}
}
}
// Apply out scale
if (has_mul_output_mask) {
@@ -544,10 +545,10 @@ struct GEMVTKernel {
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
if (leftover > 0) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
@@ -573,6 +574,7 @@ struct GEMVTKernel {
}
}
}
}
// Apply out scale
if (has_mul_output_mask) {