Allow different value dimensions in sdpa_vector (#1811)

This commit is contained in:
Angelos Katharopoulos
2025-01-31 20:58:59 -08:00
committed by GitHub
parent b7c9f1d38f
commit f5cc1eea72
6 changed files with 127 additions and 72 deletions

View File

@@ -262,6 +262,23 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
@unittest.skip("Different head and value dims is not enabled")
def test_fast_sdpa_vector_value_dims(self):
    """Check fast SDPA against the reference when v's last axis differs from q/k's.

    q and k are created with a last axis of 192 while v uses 128. q has
    size 1 on its third axis; k and v use the swept length there
    (presumably the key/value sequence length -- confirm against the
    kernel's layout). For each length the output of
    ``mx.fast.scaled_dot_product_attention`` must be ``allclose`` to the
    output of ``mlx_primitives_sdpa`` with absolute and relative
    tolerance 1e-4.

    The test is currently disabled by the ``unittest.skip`` decorator.
    """
    qk_width, v_width = 192, 128
    n_q_heads, n_kv_heads = 4, 1
    scale = 1.0
    tol = 1e-4

    # Seed once, outside the loop: every iteration keeps drawing from the
    # same random stream, in the fixed order queries -> keys -> values.
    mx.random.seed(0)
    for kv_len in (43, 128, 237, 8192):
        queries = 0.5 * mx.random.normal(shape=(1, n_q_heads, 1, qk_width))
        keys = 0.5 * mx.random.normal(shape=(1, n_kv_heads, kv_len, qk_width))
        values = 0.5 * mx.random.normal(shape=(1, n_kv_heads, kv_len, v_width))

        expected = mlx_primitives_sdpa(queries, keys, values, scale)
        actual = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=scale
        )
        self.assertTrue(mx.allclose(expected, actual, atol=tol, rtol=tol))
if __name__ == "__main__":
    # Script entry point: run this module's tests, with failfast=True so the
    # run stops at the first error or failure.
    unittest.main(failfast=True)