mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	sdpa specialization for head dim 256 (#2007)
This commit is contained in:
		| @@ -32,9 +32,11 @@ using namespace metal; | ||||
|   instantiate_sdpa_vector(type, 64, 64)          \ | ||||
|   instantiate_sdpa_vector(type, 96, 96)          \ | ||||
|   instantiate_sdpa_vector(type, 128, 128)        \ | ||||
|+  instantiate_sdpa_vector(type, 256, 256)        \ | ||||
|   instantiate_sdpa_vector_aggregation(type, 64)  \ | ||||
|   instantiate_sdpa_vector_aggregation(type, 96)  \ | ||||
|-  instantiate_sdpa_vector_aggregation(type, 128) | ||||
|+  instantiate_sdpa_vector_aggregation(type, 128) \ | ||||
|+  instantiate_sdpa_vector_aggregation(type, 256) | ||||
|  | ||||
| instantiate_sdpa_vector_heads(float) | ||||
| instantiate_sdpa_vector_heads(bfloat16_t) | ||||
|   | ||||
| @@ -720,7 +720,8 @@ array scaled_dot_product_attention( | ||||
|  | ||||
|   const bool sdpa_vector_supported_head_dim = | ||||
|       query_head_dim == value_head_dim && | ||||
|-      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); | ||||
|+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || | ||||
|+       query_head_dim == 256); | ||||
|   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && | ||||
|       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun