diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
index f509f1b1c..e52b6f23a 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
@@ -21,6 +21,8 @@
     uint3 lid [[thread_position_in_threadgroup]]);
 
 #define instantiate_attn_shapes_helper(iname, itype) \
+  instantiate_attn(iname, itype, 32, 32, 80, 4, 1)   \
+  instantiate_attn(iname, itype, 32, 16, 80, 4, 1)   \
   instantiate_attn(iname, itype, 32, 32, 64, 4, 1)   \
   instantiate_attn(iname, itype, 32, 16, 64, 4, 1)   \
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 8599afbed..5aac270f8 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -30,8 +30,8 @@ void sdpa_full_self_attention_metal(
   int wn = 1;
 
   int bq = 32;
-  int bk = 16;
-  int bd = 64;
+  int bk = 32;
+  int bd = q.shape(-1);
 
   std::ostringstream kname;
   kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index d27120cac..c90b5f5dd 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -644,7 +644,7 @@ array scaled_dot_product_attention(
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
   const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 128;
+      query_head_dim == 64 || query_head_dim == 80;
 
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       !mask.has_value() && sdpa_full_supported_head_dim &&
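
Not part of the patch: a minimal sketch of how the new head-dim-80 full-attention path could be exercised from MLX's Python binding. It assumes the public mx.fast.scaled_dot_product_attention(q, k, v, scale=..., mask=None) API, and that a sequence length of 1024 exceeds the internal threshold that selects the full steel kernel over the vector kernel; the chosen shapes are illustrative only.

import mlx.core as mx

# Hypothetical shapes: batch 1, 8 heads, sequence 1024, head dim 80
# (the head dim the patch adds to sdpa_full_supported_head_dim).
B, H, L, D = 1, 8, 1024, 80
q = mx.random.normal((B, H, L, D), dtype=mx.float16)
k = mx.random.normal((B, H, L, D), dtype=mx.float16)
v = mx.random.normal((B, H, L, D), dtype=mx.float16)

# No mask, long sequence: eligible for the full (steel) attention kernel.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5)
mx.eval(out)
print(out.shape)  # (1, 8, 1024, 80)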