Allow boolean mask in sdpa (#1753)

* allow boolean mask in sdpa

* more permissive donation in ternary
Awni Hannun
2025-01-06 16:57:07 -08:00
committed by GitHub
parent 25b3a3e541
commit d5ec172c95
4 changed files with 28 additions and 5 deletions

@@ -164,9 +164,11 @@ void init_fast(nb::module_& parent_module) {
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
-mask (array, optional): An additive mask to apply to the query-key
-scores. The mask can have at most 4 dimensions and must be
-broadcast-compatible with the shape ``[B, N, T_q, T_kv]``.
+mask (array, optional): A boolean or additive mask to apply to the
+query-key scores. The mask can have at most 4 dimensions and must
+be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an
+additive mask is given its type must promote to the promoted
+type of ``q``, ``k``, and ``v``.
Returns:
array: The output array.
)pbdoc");
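
A minimal pure-Python sketch of how a boolean mask might relate to an additive one for a single attention head. It is not the library's implementation: the helper names (`softmax`, `attention_weights`) are hypothetical, and it assumes that `True` means "keep this position" and that a masked-out position corresponds to adding `-inf` to its score, neither of which is stated in the docstring above.

```python
import math


def softmax(row):
    # Numerically stable softmax over one row of scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(scores, mask=None):
    """Return softmax(scores) per query row, after applying an optional mask.

    scores: T_q x T_kv nested list of floats (already scaled).
    mask:   same shape; entries are either bool (assumed: True = keep)
            or float (added to the score).
    Assumes every row keeps at least one position.
    """
    out = []
    for i, row in enumerate(scores):
        if mask is not None:
            masked = []
            for j, s in enumerate(row):
                m = mask[i][j]
                if isinstance(m, bool):
                    # Boolean mask: drop the position by sending its score to -inf.
                    masked.append(s if m else -math.inf)
                else:
                    # Additive mask: add the mask value to the score.
                    masked.append(s + m)
            row = masked
        out.append(softmax(row))
    return out


# Two queries attending over three keys with a causal-style pattern.
scores = [[0.5, 1.0, -0.3], [0.2, 0.1, 0.9]]
causal_bool = [[True, False, False], [True, True, False]]
causal_add = [[0.0 if keep else -math.inf for keep in row] for row in causal_bool]

w_bool = attention_weights(scores, causal_bool)
w_add = attention_weights(scores, causal_add)
```

Under these assumptions the two masks yield identical weights, which is why accepting a boolean mask can be a convenience rather than a change in semantics.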