mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
full row mask in sdpa consistently gives nan (#2406)
This commit is contained in:
@@ -708,7 +708,10 @@ array scaled_dot_product_attention(
|
||||
}
|
||||
if (mask.dtype() == bool_) {
|
||||
scores = where(
|
||||
mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
|
||||
mask,
|
||||
scores,
|
||||
array(-std::numeric_limits<float>::infinity(), scores.dtype()),
|
||||
s);
|
||||
} else {
|
||||
scores = add(scores, mask, s);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user