mirror of https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
fix per-example mask + docs in sdpa (#1574)
@@ -7,14 +7,16 @@ import numpy as np
 
 
 # SDPA for MHA (n_heads == n_kv_heads)
-def mlx_primitives_sdpa(q, k, v, scale):
+def mlx_primitives_sdpa(q, k, v, scale, mask=None):
     p = (q * scale) @ k.transpose(0, 1, 3, 2)
+    if mask is not None:
+        p += mask
     scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
     return scores @ v
 
 
 # SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
-def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
+def mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=None):
     n_repeats = q.shape[1] // k.shape[1]
 
     # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
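Note: the tiling elided between these two hunks repeats each KV head n_repeats times along the head axis so that k and v match q's head count before calling the MHA reference. A minimal sketch of what such a helper looks like, in the style of the mistral.py tiling the comment cites; this is illustrative, not the exact elided code:

    def repeat(a):
        # (B, n_kv_heads, L, D) -> (B, n_kv_heads, n_repeats, L, D)
        a = mx.concatenate([mx.expand_dims(a, 2)] * n_repeats, axis=2)
        # collapse the two head axes -> (B, n_kv_heads * n_repeats, L, D)
        return a.reshape([a.shape[0], -1, a.shape[3], a.shape[4]])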
@@ -28,7 +30,7 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
 
     k, v = map(repeat, (k, v))
 
-    return mlx_primitives_sdpa(q, k, v, scale)
+    return mlx_primitives_sdpa(q, k, v, scale, mask=mask)
 
 
 class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
@@ -176,6 +178,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale)
         self.assertTrue(mx.allclose(y, y_hat, atol=atol))
 
+        # Test with per-example mask
+        q = mx.random.normal(shape=(2, 8, 4, 32))
+        k = mx.random.normal(shape=(2, 2, 8, 32))
+        v = mx.random.normal(shape=(2, 2, 8, 32))
+        mask = 10 * mx.random.normal(shape=(2, 1, 4, 8))
+        y = mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=mask)
+        y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        self.assertTrue(mx.allclose(y, y_hat, atol=atol))
+
 
 if __name__ == "__main__":
     unittest.main(failfast=True)
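The mask here is additive and is broadcast against the (B, n_heads, L_q, L_k) score tensor, which is what makes a per-example mask (leading batch dimension greater than 1) work. A minimal usage sketch for a per-example padding mask, assuming the mask broadcasts against the scores exactly as in the reference implementation above; the padding_mask helper and the lengths values are illustrative, not part of this commit:

    import mlx.core as mx

    def padding_mask(lengths, L_q, L_k):
        # Additive mask of shape (B, 1, L_q, L_k): 0 where the key position
        # is valid, a large negative number where it is padding, so softmax
        # gives padded positions ~zero weight.
        positions = mx.arange(L_k)[None, None, None, :]
        valid = positions < lengths[:, None, None, None]
        mask = mx.where(valid, 0.0, -1e9)
        return mx.broadcast_to(mask, (lengths.shape[0], 1, L_q, L_k))

    B, n_heads, n_kv_heads, L_q, L_k, D = 2, 8, 2, 4, 8, 32
    scale = D**-0.5
    q = mx.random.normal(shape=(B, n_heads, L_q, D))
    k = mx.random.normal(shape=(B, n_kv_heads, L_k, D))
    v = mx.random.normal(shape=(B, n_kv_heads, L_k, D))

    # e.g. the second example in the batch has only 5 valid key positions
    mask = padding_mask(mx.array([8, 5]), L_q, L_k)
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
    print(out.shape)  # (2, 8, 4, 32)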