mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Ensure no oob read in gemv_masked (#2508)
This commit is contained in:
parent
73f22d6226
commit
1df9887998
@ -262,10 +262,10 @@ struct GEMVKernel {
|
||||
vec_mask_offset += vec_mask_step;
|
||||
}
|
||||
|
||||
if (leftover > 0 &&
|
||||
(!has_operand_mask ||
|
||||
if (leftover > 0) {
|
||||
if (!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset])))) {
|
||||
bool(vec_mask[vec_mask_offset]))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
@ -295,6 +295,7 @@ struct GEMVKernel {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply out scale
|
||||
if (has_mul_output_mask) {
|
||||
@ -544,10 +545,10 @@ struct GEMVTKernel {
|
||||
vec_mask_offset += vec_mask_step;
|
||||
}
|
||||
|
||||
if (leftover > 0 &&
|
||||
(!has_operand_mask ||
|
||||
if (leftover > 0) {
|
||||
if (!has_operand_mask ||
|
||||
(bool(mat_mask[mat_mask_offset]) &&
|
||||
bool(vec_mask[vec_mask_offset])))) {
|
||||
bool(vec_mask[vec_mask_offset]))) {
|
||||
T block_scale{1};
|
||||
if (has_mul_operand_mask) {
|
||||
block_scale =
|
||||
@ -573,6 +574,7 @@ struct GEMVTKernel {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply out scale
|
||||
if (has_mul_output_mask) {
|
||||
|
Loading…
Reference in New Issue
Block a user