diff --git a/benchmarks/python/sdpa_bench.py b/benchmarks/python/sdpa_bench.py
index c26c9ce04..23383475e 100644
--- a/benchmarks/python/sdpa_bench.py
+++ b/benchmarks/python/sdpa_bench.py
@@ -144,7 +144,7 @@ if __name__ == "__main__":
     transposes = (False,)
 
     # fmt: off
-    shapes = (
+    shapes_64 = (
         # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
         (  1,    32,    32,     64,   32,   32),
         (  1,    64,    64,     64,   32,   32),
@@ -162,9 +162,16 @@ if __name__ == "__main__":
         (  1,  2048,  2048,     80,   32,   32),
         (  1,  4096,  4096,     80,   32,   32),
     )
+
+    shapes_128 = (
+        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
+        (  1,  1024,  1024,    128,   32,   32),
+        (  1,  2048,  2048,    128,   32,   32),
+        (  1,  4096,  4096,    128,   32,   32),
+    )
     # fmt: on
 
-    shapes = shapes + shapes_80
+    shapes = shapes_64 + shapes_80 + shapes_128
 
     print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
 
diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
index d9e2ce2ca..99c27de25 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
@@ -21,7 +21,7 @@
         uint3 lid [[thread_position_in_threadgroup]]);
 
 #define instantiate_attn_shapes_helper(iname, itype)  \
-    instantiate_attn(iname, itype, 32, 32, 128, 4, 1) \
+    instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
     instantiate_attn(iname, itype, 32, 32, 80, 4, 1)  \
     instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
 
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 8080c0f6f..eadf2a663 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -29,9 +29,9 @@ void sdpa_full_self_attention_metal(
   int wm = 4;
   int wn = 1;
 
-  int bq = 32;
-  int bk = 32;
   int bd = q.shape(-1);
+  int bq = 32;
+  int bk = bd < 128 ? 32 : 16;
 
   int B = q.shape(0);
   int H = q.shape(1);
 
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 47f17745c..8cbf0d015 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -644,7 +644,7 @@ array scaled_dot_product_attention(
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
   const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 80;
+      query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
 
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       !mask.has_value() && sdpa_full_supported_head_dim &&
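
For context, a minimal usage sketch (not part of the diff) of how the newly enabled head_dim = 128 fused path can be exercised from the public Python API. It assumes an MLX build that includes this change, running on an Apple-silicon GPU; the (B, n_heads, seq_len, head_dim) shape convention follows the benchmark above.

import math

import mlx.core as mx

B, n_heads, seq_len, head_dim = 1, 32, 2048, 128
q = mx.random.normal((B, n_heads, seq_len, head_dim), dtype=mx.float16)
k = mx.random.normal((B, n_heads, seq_len, head_dim), dtype=mx.float16)
v = mx.random.normal((B, n_heads, seq_len, head_dim), dtype=mx.float16)

# With no mask and a long sequence, sdpa_full_supported_head_dim now accepts
# 128, so this call should dispatch to the fused steel_attention kernel
# instantiated with bq=32, bk=16 for bd=128 instead of the unfused fallback.
out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=1.0 / math.sqrt(head_dim), mask=None
)
mx.eval(out)
print(out.shape)  # (1, 32, 2048, 128)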