mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-23 22:18:13 +08:00
full row mask in sdpa consistently gives nan (#2406)
This commit is contained in:
@@ -708,7 +708,10 @@ array scaled_dot_product_attention(
|
||||
}
|
||||
if (mask.dtype() == bool_) {
|
||||
scores = where(
|
||||
mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
|
||||
mask,
|
||||
scores,
|
||||
array(-std::numeric_limits<float>::infinity(), scores.dtype()),
|
||||
s);
|
||||
} else {
|
||||
scores = add(scores, mask, s);
|
||||
}
|
||||
|
Reference in New Issue
Block a user