fix per-example mask + docs in sdpa (#1574)

Author: Awni Hannun
Date: 2024-11-08 11:51:15 -08:00
Committed by: GitHub
Parent: 9f0d5c12fc
Commit: 91c0277356
3 changed files with 42 additions and 11 deletions


@@ -140,12 +140,23 @@ void init_fast(nb::module_& parent_module) {
 Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
 and ``v`` inputs should not be pre-tiled to match ``q``.
+
+In the following, the dimensions are given by:
+
+* ``B``: The batch size.
+* ``N_q``: The number of query heads.
+* ``N_kv``: The number of key and value heads.
+* ``T_q``: The number of queries per example.
+* ``T_kv``: The number of keys and values per example.
+* ``D``: The per-head dimension.
+
 Args:
-    q (array): Input query array.
-    k (array): Input keys array.
-    v (array): Input values array.
+    q (array): Queries with shape ``[B, N_q, T_q, D]``.
+    k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
+    v (array): Values with shape ``[B, N_kv, T_kv, D]``.
     scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``).
-    mask (array, optional): An additive mask to apply to the query-key scores.
+    mask (array, optional): An additive mask to apply to the query-key
+        scores. The mask can have at most 4 dimensions and must be
+        broadcast-compatible with the shape ``[B, N, T_q, T_kv]``.
 Returns:
     array: The output array.
 )pbdoc");


@@ -7,14 +7,16 @@ import numpy as np
 # SDPA for MHA (n_heads == n_kv_heads)
-def mlx_primitives_sdpa(q, k, v, scale):
+def mlx_primitives_sdpa(q, k, v, scale, mask=None):
     p = (q * scale) @ k.transpose(0, 1, 3, 2)
+    if mask is not None:
+        p += mask
     scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
     return scores @ v


 # SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
-def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
+def mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=None):
     n_repeats = q.shape[1] // k.shape[1]

     # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
@@ -28,7 +30,7 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
     k, v = map(repeat, (k, v))

-    return mlx_primitives_sdpa(q, k, v, scale)
+    return mlx_primitives_sdpa(q, k, v, scale, mask=mask)


 class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
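Note (assumption, not part of the diff): the kv tiling elided between the two hunks above expands the ``n_kv`` heads so the GQA reference can reuse the MHA path. Roughly, with the helper name ``tile_kv_like_q`` chosen here for illustration:

    import mlx.core as mx

    def tile_kv_like_q(q, k, v):
        n_repeats = q.shape[1] // k.shape[1]
        B, n_heads, L = q.shape[0], q.shape[1], k.shape[2]

        def repeat(a):
            # Insert a new axis, repeat each kv head n_repeats times,
            # then fold it back into the head dimension.
            a = mx.concatenate([mx.expand_dims(a, 2)] * n_repeats, axis=2)
            return a.reshape([B, n_heads, L, -1])

        return repeat(k), repeat(v)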
@@ -176,6 +178,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale)
         self.assertTrue(mx.allclose(y, y_hat, atol=atol))
+
+        # Test with per-example mask
+        q = mx.random.normal(shape=(2, 8, 4, 32))
+        k = mx.random.normal(shape=(2, 2, 8, 32))
+        v = mx.random.normal(shape=(2, 2, 8, 32))
+        mask = 10 * mx.random.normal(shape=(2, 1, 4, 8))
+        y = mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=mask)
+        y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        self.assertTrue(mx.allclose(y, y_hat, atol=atol))


 if __name__ == "__main__":
     unittest.main(failfast=True)
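Note (illustration, not part of the commit): the test uses a random per-example mask; in practice such a mask typically encodes per-example padding, e.g. a large negative additive value on padded key positions. A self-contained sketch, with the variable names chosen here:

    import mlx.core as mx

    B, T_q, T_kv = 2, 4, 8
    lengths = mx.array([8, 5])  # valid key/value length per example

    # True where the key position is valid for that example
    valid = mx.arange(T_kv)[None, :] < lengths[:, None]

    # Additive mask, shape [B, 1, 1, T_kv]: 0 for valid keys, -1e9 for padding.
    # Broadcasts against [B, N, T_q, T_kv] inside scaled_dot_product_attention.
    mask = mx.where(valid, 0.0, -1e9)[:, None, None, :]
    print(mask.shape)  # (2, 1, 1, 8)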