diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index ea80396df..4abef4c49 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -32,9 +32,11 @@ using namespace metal; instantiate_sdpa_vector(type, 64, 64) \ instantiate_sdpa_vector(type, 96, 96) \ instantiate_sdpa_vector(type, 128, 128) \ + instantiate_sdpa_vector(type, 256, 256) \ instantiate_sdpa_vector_aggregation(type, 64) \ instantiate_sdpa_vector_aggregation(type, 96) \ - instantiate_sdpa_vector_aggregation(type, 128) + instantiate_sdpa_vector_aggregation(type, 128) \ + instantiate_sdpa_vector_aggregation(type, 256) instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index bcc5ccbd3..ea591d675 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -720,7 +720,8 @@ array scaled_dot_product_attention( const bool sdpa_vector_supported_head_dim = query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); + (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || + query_head_dim == 256); const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);