fix matrix sdpa

This commit is contained in:
Awni Hannun
2025-09-09 12:52:56 -07:00
parent 0fe25eb588
commit 836f019d3b
2 changed files with 4 additions and 5 deletions

View File

@@ -731,7 +731,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
for T_kv in [128, 4096]:
for T_q in [1]: # , 128]:
for T_q in [1, 128]:
for N_kv in [2, 8]:
q = mx.random.normal(shape=(B, N_q, T_q, D))
k = mx.random.normal(shape=(B, N_kv, T_kv, D))