mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Ensure no oob read in gemv_masked (#2508)
This commit is contained in:
parent
73f22d6226
commit
1df9887998
@ -262,36 +262,37 @@ struct GEMVKernel {
|
|||||||
vec_mask_offset += vec_mask_step;
|
vec_mask_offset += vec_mask_step;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (leftover > 0 &&
|
if (leftover > 0) {
|
||||||
(!has_operand_mask ||
|
if (!has_operand_mask ||
|
||||||
(bool(mat_mask[mat_mask_offset]) &&
|
(bool(mat_mask[mat_mask_offset]) &&
|
||||||
bool(vec_mask[vec_mask_offset])))) {
|
bool(vec_mask[vec_mask_offset]))) {
|
||||||
T block_scale{1};
|
T block_scale{1};
|
||||||
if (has_mul_operand_mask) {
|
if (has_mul_operand_mask) {
|
||||||
block_scale =
|
block_scale =
|
||||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||||
}
|
|
||||||
|
|
||||||
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
|
|
||||||
|
|
||||||
// Apply scale
|
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
|
||||||
v_coeff[tn] *= block_scale;
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Per thread work loop
|
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
|
||||||
for (int tm = 0; tm < TM; tm++) {
|
|
||||||
// Load for the row
|
|
||||||
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
|
||||||
|
|
||||||
// Accumulate results
|
// Apply scale
|
||||||
|
if (has_mul_operand_mask) {
|
||||||
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
|
v_coeff[tn] *= block_scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per thread work loop
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
for (int tm = 0; tm < TM; tm++) {
|
||||||
result[tm] += inter[tn] * v_coeff[tn];
|
// Load for the row
|
||||||
|
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
||||||
|
|
||||||
|
// Accumulate results
|
||||||
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
|
result[tm] += inter[tn] * v_coeff[tn];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -544,31 +545,32 @@ struct GEMVTKernel {
|
|||||||
vec_mask_offset += vec_mask_step;
|
vec_mask_offset += vec_mask_step;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (leftover > 0 &&
|
if (leftover > 0) {
|
||||||
(!has_operand_mask ||
|
if (!has_operand_mask ||
|
||||||
(bool(mat_mask[mat_mask_offset]) &&
|
(bool(mat_mask[mat_mask_offset]) &&
|
||||||
bool(vec_mask[vec_mask_offset])))) {
|
bool(vec_mask[vec_mask_offset]))) {
|
||||||
T block_scale{1};
|
T block_scale{1};
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
block_scale =
|
|
||||||
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
|
|
||||||
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
|
||||||
|
|
||||||
if (has_mul_operand_mask) {
|
if (has_mul_operand_mask) {
|
||||||
v_coeff[tm] *= block_scale;
|
block_scale =
|
||||||
|
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
||||||
}
|
}
|
||||||
|
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
||||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
|
||||||
}
|
|
||||||
|
|
||||||
MLX_MTL_PRAGMA_UNROLL
|
if (has_mul_operand_mask) {
|
||||||
for (int tn = 0; tn < TN; tn++) {
|
v_coeff[tm] *= block_scale;
|
||||||
result[tn] += v_coeff[tm] * inter[tn];
|
}
|
||||||
|
|
||||||
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
|
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||||
|
}
|
||||||
|
|
||||||
|
MLX_MTL_PRAGMA_UNROLL
|
||||||
|
for (int tn = 0; tn < TN; tn++) {
|
||||||
|
result[tn] += v_coeff[tm] * inter[tn];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user