mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:10:15 +08:00
Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to memory efficient GPU shader at proscribed sequence length. Otherwise, utilizes aggregate MLX primitives for best latency.
This commit is contained in:
@@ -86,7 +86,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
|
||||
|
||||
reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
|
||||
o_mlx = mx.fast.scaled_dot_product_attention(
|
||||
q_mlx, k_mlx, v_mlx, scale=scale
|
||||
q_mlx, k_mlx, v_mlx, scale=scale, memory_efficient_threshold=2
|
||||
)
|
||||
|
||||
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
|
||||
|
Reference in New Issue
Block a user