mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix matrix sdpa
This commit is contained in:
@@ -731,7 +731,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
||||
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)
|
||||
|
||||
for T_kv in [128, 4096]:
|
||||
for T_q in [1]: # , 128]:
|
||||
for T_q in [1, 128]:
|
||||
for N_kv in [2, 8]:
|
||||
q = mx.random.normal(shape=(B, N_q, T_q, D))
|
||||
k = mx.random.normal(shape=(B, N_kv, T_kv, D))
|
||||
|
||||
Reference in New Issue
Block a user