Add memory_efficient_threshold kwarg to sdpa kernel (#1319)

Allows opt-in to memory efficient GPU shader at proscribed sequence
length.  Otherwise, utilizes aggregate MLX primitives for best latency.
This commit is contained in:
Brian Keene
2024-08-12 15:57:09 -04:00
committed by GitHub
parent 9231617eb3
commit 19fb69e2ed
4 changed files with 13 additions and 4 deletions

View File

@@ -86,7 +86,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
o_mlx = mx.fast.scaled_dot_product_attention(
q_mlx, k_mlx, v_mlx, scale=scale
q_mlx, k_mlx, v_mlx, scale=scale, memory_efficient_threshold=2
)
self.assertListEqual(list(reference.shape), list(o_mlx.shape))