Ensure no oob read in gemv_masked (#2508)

This commit is contained in:
Angelos Katharopoulos 2025-08-17 08:42:33 -07:00 committed by GitHub
parent 73f22d6226
commit 1df9887998
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -262,36 +262,37 @@ struct GEMVKernel {
vec_mask_offset += vec_mask_step; vec_mask_offset += vec_mask_step;
} }
if (leftover > 0 && if (leftover > 0) {
(!has_operand_mask || if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) && (bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) { bool(vec_mask[vec_mask_offset]))) {
T block_scale{1}; T block_scale{1};
if (has_mul_operand_mask) { if (has_mul_operand_mask) {
block_scale = block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
} }
}
// Per thread work loop load_safe<AccT>(in_vec, v_coeff, bn, in_size);
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
// Accumulate results // Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) { for (int tm = 0; tm < TM; tm++) {
result[tm] += inter[tn] * v_coeff[tn]; // Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
} }
} }
} }
@ -544,31 +545,32 @@ struct GEMVTKernel {
vec_mask_offset += vec_mask_step; vec_mask_offset += vec_mask_step;
} }
if (leftover > 0 && if (leftover > 0) {
(!has_operand_mask || if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) && (bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) { bool(vec_mask[vec_mask_offset]))) {
T block_scale{1}; T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
if (has_mul_operand_mask) { if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale; block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
} }
MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
for (int tn = 0; tn < TN; tn++) { v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
MLX_MTL_PRAGMA_UNROLL if (has_mul_operand_mask) {
for (int tn = 0; tn < TN; tn++) { v_coeff[tm] *= block_scale;
result[tn] += v_coeff[tm] * inter[tn]; }
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
} }
} }
} }