mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
sdpa specialization for head dim 256 (#2007)
This commit is contained in:
@@ -32,9 +32,11 @@ using namespace metal;
|
||||
instantiate_sdpa_vector(type, 64, 64) \
|
||||
instantiate_sdpa_vector(type, 96, 96) \
|
||||
instantiate_sdpa_vector(type, 128, 128) \
|
||||
instantiate_sdpa_vector(type, 256, 256) \
|
||||
instantiate_sdpa_vector_aggregation(type, 64) \
|
||||
instantiate_sdpa_vector_aggregation(type, 96) \
|
||||
instantiate_sdpa_vector_aggregation(type, 128)
|
||||
instantiate_sdpa_vector_aggregation(type, 128) \
|
||||
instantiate_sdpa_vector_aggregation(type, 256)
|
||||
|
||||
instantiate_sdpa_vector_heads(float)
|
||||
instantiate_sdpa_vector_heads(bfloat16_t)
|
||||
|
Reference in New Issue
Block a user