mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Allow boolean mask in sdpa (#1753)
* allow boolean mask in sdpa
* more permissive donation in ternary
This commit is contained in:
@@ -659,7 +659,12 @@ array scaled_dot_product_attention(
|
||||
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
|
||||
}
|
||||
}
|
||||
- scores = add(scores, mask, s);
|
||||
+ if (mask.dtype() == bool_) {
|
||||
+ scores = where(
|
||||
+ mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
|
||||
+ } else {
|
||||
+ scores = add(scores, mask, s);
|
||||
+ }
|
||||
}
|
||||
scores = softmax(scores, std::vector<int>{-1}, true, s);
|
||||
auto out = matmul(scores, v, s);
|
||||
|
||||
Reference in New Issue
Block a user