fix 2 pass

This commit is contained in:
Awni Hannun
2025-09-09 12:12:09 -07:00
parent 3ca3ab9dcd
commit 0fe25eb588
2 changed files with 13 additions and 12 deletions

View File

@@ -263,7 +263,7 @@ template <typename T, int D, int V = D>
U max_score = -INFINITY; U max_score = -INFINITY;
U sum_exp_score = 0; U sum_exp_score = 0;
if (has_sinks && simd_gid == 0) { if (has_sinks && block_idx == 0 && simd_gid == 0) {
int q_head_idx = q_batch_head_idx % num_q_heads; int q_head_idx = q_batch_head_idx % num_q_heads;
max_score = static_cast<U>(sinks[q_head_idx]); max_score = static_cast<U>(sinks[q_head_idx]);
sum_exp_score = 1; sum_exp_score = 1;

View File

@@ -730,7 +730,8 @@ class TestSDPA(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks) mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
for T_q in [1, 128]: for T_kv in [128, 4096]:
for T_q in [1]: # , 128]:
for N_kv in [2, 8]: for N_kv in [2, 8]:
q = mx.random.normal(shape=(B, N_q, T_q, D)) q = mx.random.normal(shape=(B, N_q, T_q, D))
k = mx.random.normal(shape=(B, N_kv, T_kv, D)) k = mx.random.normal(shape=(B, N_kv, T_kv, D))