fix per-example mask + docs in sdpa (#1574)

This commit is contained in:
Awni Hannun
2024-11-08 11:51:15 -08:00
committed by GitHub
parent 9f0d5c12fc
commit 91c0277356
3 changed files with 42 additions and 11 deletions

View File

@@ -140,12 +140,23 @@ void init_fast(nb::module_& parent_module) {
Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
and ``v`` inputs should not be pre-tiled to match ``q``.
In the following the dimensions are given by:
* ``B``: The batch size.
* ``N_q``: The number of query heads.
* ``N_kv``: The number of key and value heads.
* ``T_q``: The number of queries per example.
* ``T_kv``: The number of keys and values per example.
* ``D``: The per-head dimension.
Args:
q (array): Input query array.
k (array): Input keys array.
v (array): Input values array.
q (array): Queries with shape ``[B, N_q, T_q, D]``.
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
mask (array, optional): An additive mask to apply to the query-key scores.
mask (array, optional): An additive mask to apply to the query-key
scores. The mask can have at most 4 dimensions and must be
broadcast-compatible with the shape ``[B, N, T_q, T_kv]``.
Returns:
array: The output array.
)pbdoc");