mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Ensure no oob read in gemv_masked (#2508)
This commit is contained in:
parent
73f22d6226
commit
1df9887998
@ -262,36 +262,37 @@ struct GEMVKernel {
|
||||
vec_mask_offset += vec_mask_step;
|
||||
}
|
||||
|
||||
if (leftover > 0 &&
|
||||
(!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset])))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||
}
|
||||
|
||||
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
|
||||
|
||||
// Apply scale
|
||||
if (has_mul_operand_mask) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] *= block_scale;
|
||||
if (leftover > 0) {
|
||||
if (!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset]))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||
}
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
||||
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
|
||||
|
||||
// Accumulate results
|
||||
// Apply scale
|
||||
if (has_mul_operand_mask) {
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] *= block_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
||||
|
||||
// Accumulate results
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -544,31 +545,32 @@ struct GEMVTKernel {
|
||||
vec_mask_offset += vec_mask_step;
|
||||
}
|
||||
|
||||
if (leftover > 0 &&
|
||||
(!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset])))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||
}
|
||||
|
||||
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
|
||||
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
||||
|
||||
if (leftover > 0) {
|
||||
if (!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset]))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
v_coeff[tm] *= block_scale;
|
||||
block_scale =
|
||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||
}
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
|
||||
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
if (has_mul_operand_mask) {
|
||||
v_coeff[tm] *= block_scale;
|
||||
}
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user