full row mask in sdpa consistently gives nan (#2406)

This commit is contained in:
Awni Hannun
2025-07-23 16:37:03 -07:00
committed by GitHub
parent 0f5ce173da
commit e1840853ce
2 changed files with 16 additions and 1 deletion

View File

@@ -708,7 +708,10 @@ array scaled_dot_product_attention(
}
if (mask.dtype() == bool_) {
scores = where(
-mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
+mask,
+scores,
+array(-std::numeric_limits<float>::infinity(), scores.dtype()),
+s);
} else {
scores = add(scores, mask, s);
}