mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix 2 pass
This commit is contained in:
@@ -263,7 +263,7 @@ template <typename T, int D, int V = D>
|
|||||||
|
|
||||||
U max_score = -INFINITY;
|
U max_score = -INFINITY;
|
||||||
U sum_exp_score = 0;
|
U sum_exp_score = 0;
|
||||||
if (has_sinks && simd_gid == 0) {
|
if (has_sinks && block_idx == 0 && simd_gid == 0) {
|
||||||
int q_head_idx = q_batch_head_idx % num_q_heads;
|
int q_head_idx = q_batch_head_idx % num_q_heads;
|
||||||
max_score = static_cast<U>(sinks[q_head_idx]);
|
max_score = static_cast<U>(sinks[q_head_idx]);
|
||||||
sum_exp_score = 1;
|
sum_exp_score = 1;
|
||||||
|
|||||||
@@ -730,18 +730,19 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
|
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
|
||||||
|
|
||||||
for T_q in [1, 128]:
|
for T_kv in [128, 4096]:
|
||||||
for N_kv in [2, 8]:
|
for T_q in [1]: # , 128]:
|
||||||
q = mx.random.normal(shape=(B, N_q, T_q, D))
|
for N_kv in [2, 8]:
|
||||||
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
|
q = mx.random.normal(shape=(B, N_q, T_q, D))
|
||||||
v = mx.random.normal(shape=(B, N_kv, T_kv, D))
|
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
|
||||||
sinks = 10 * mx.random.normal(shape=(N_q,))
|
v = mx.random.normal(shape=(B, N_kv, T_kv, D))
|
||||||
|
sinks = 10 * mx.random.normal(shape=(N_q,))
|
||||||
|
|
||||||
expected = mlx_ref_attn(q, k, v, scale, sinks=sinks)
|
expected = mlx_ref_attn(q, k, v, scale, sinks=sinks)
|
||||||
out = mx.fast.scaled_dot_product_attention(
|
out = mx.fast.scaled_dot_product_attention(
|
||||||
q, k, v, scale=scale, sinks=sinks
|
q, k, v, scale=scale, sinks=sinks
|
||||||
)
|
)
|
||||||
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
|
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user