fix 2 pass

This commit is contained in:
Awni Hannun
2025-09-09 12:12:09 -07:00
parent 3ca3ab9dcd
commit 0fe25eb588
2 changed files with 13 additions and 12 deletions

View File

@@ -730,18 +730,19 @@ class TestSDPA(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
for T_q in [1, 128]:
for N_kv in [2, 8]:
q = mx.random.normal(shape=(B, N_q, T_q, D))
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
v = mx.random.normal(shape=(B, N_kv, T_kv, D))
sinks = 10 * mx.random.normal(shape=(N_q,))
for T_kv in [128, 4096]:
for T_q in [1]: # , 128]:
for N_kv in [2, 8]:
q = mx.random.normal(shape=(B, N_q, T_q, D))
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
v = mx.random.normal(shape=(B, N_kv, T_kv, D))
sinks = 10 * mx.random.normal(shape=(N_q,))
expected = mlx_ref_attn(q, k, v, scale, sinks=sinks)
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, sinks=sinks
)
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
expected = mlx_ref_attn(q, k, v, scale, sinks=sinks)
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, sinks=sinks
)
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
if __name__ == "__main__":