Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-19 00:04:41 +08:00
fix per-example mask + docs in sdpa (#1574)
@@ -140,12 +140,23 @@ void init_fast(nb::module_& parent_module) {
 Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
 and ``v`` inputs should not be pre-tiled to match ``q``.

+In the following the dimensions are given by:
+
+* ``B``: The batch size.
+* ``N_q``: The number of query heads.
+* ``N_kv``: The number of key and value heads.
+* ``T_q``: The number of queries per example.
+* ``T_kv``: The number of keys and values per example.
+* ``D``: The per-head dimension.
+
 Args:
-    q (array): Input query array.
-    k (array): Input keys array.
-    v (array): Input values array.
+    q (array): Queries with shape ``[B, N_q, T_q, D]``.
+    k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
+    v (array): Values with shape ``[B, N_kv, T_kv, D]``.
     scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
-    mask (array, optional): An additive mask to apply to the query-key scores.
+    mask (array, optional): An additive mask to apply to the query-key
+       scores. The mask can have at most 4 dimensions and must be
+       broadcast-compatible with the shape ``[B, N, T_q, T_kv]``.
 Returns:
     array: The output array.
 )pbdoc");
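The shape rules in the updated docstring can be sketched in plain Python. This is only an illustration of how the documented constraints might be checked, not MLX's actual validation code. The helper names `check_sdpa_shapes` and `default_scale` are hypothetical. The sketch also makes three assumptions: that ``N`` in the mask shape refers to ``N_q``, that ``N_q`` must be a multiple of ``N_kv`` for grouped-query attention, and that "broadcast-compatible" means the mask broadcasts to the score shape without enlarging it.

```python
import math


def default_scale(q_shape):
    """Typical scale from the docstring: 1 / sqrt(D), with D the last dim of q."""
    return 1.0 / math.sqrt(q_shape[-1])


def check_sdpa_shapes(q_shape, k_shape, v_shape, mask_shape=None):
    """Hypothetical helper: check shapes against the documented layout.

    q: [B, N_q, T_q, D], k and v: [B, N_kv, T_kv, D].
    Returns the score shape (B, N_q, T_q, T_kv) that a mask must broadcast to.
    """
    if not (len(q_shape) == len(k_shape) == len(v_shape) == 4):
        raise ValueError("q, k and v must all be 4-dimensional")
    B, N_q, T_q, D = q_shape
    if tuple(v_shape) != tuple(k_shape):
        raise ValueError("k and v must share the shape [B, N_kv, T_kv, D]")
    B_k, N_kv, T_kv, D_k = k_shape
    if B_k != B or D_k != D:
        raise ValueError("k and v must match q in batch size and head dim")
    # Assumption: GQA/MQA needs N_q to be a multiple of N_kv, so that k and v
    # can be shared across query heads without pre-tiling them.
    if N_q % N_kv != 0:
        raise ValueError("N_q must be a multiple of N_kv")

    scores_shape = (B, N_q, T_q, T_kv)
    if mask_shape is not None:
        if len(mask_shape) > 4:
            raise ValueError("mask can have at most 4 dimensions")
        # Right-align the mask shape, padding missing leading dims with 1.
        padded = (1,) * (4 - len(mask_shape)) + tuple(mask_shape)
        for m, s in zip(padded, scores_shape):
            # Assumption: the mask may only be expanded, never the scores.
            if m != 1 and m != s:
                raise ValueError(
                    f"mask shape {tuple(mask_shape)} does not broadcast "
                    f"to {scores_shape}"
                )
    return scores_shape


# Grouped-query attention: 8 query heads share 2 key/value heads.
q_shape, kv_shape = (2, 8, 5, 64), (2, 2, 7, 64)
# A per-example mask carries the batch dimension and broadcasts over heads.
print(check_sdpa_shapes(q_shape, kv_shape, kv_shape, mask_shape=(2, 1, 5, 7)))
print(default_scale(q_shape))
```

With MLX installed, arrays of these shapes would be passed to `mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)`. The per-example mask of shape `(2, 1, 5, 7)` is the case the commit title refers to.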