Make attention faster for some models (#574)

* make attention faster for a couple models

* remove unused generation flags

* add comment on lora

* include text files as well
Author: Awni Hannun
Date: 2024-03-14 21:35:54 -07:00
Committed by: GitHub
Parent: 3f3741d229
Commit: e4b19bb9e1
6 changed files with 35 additions and 56 deletions


@@ -68,7 +68,6 @@ class RoPEAttention(nn.Module):
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
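For context, the `scale` in the hunk above is the standard 1/sqrt(d_k) softmax scaling factor (note that `math.sqrt(1 / d)` equals `1 / math.sqrt(d)`). A minimal NumPy sketch of the scaled dot-product attention this value feeds into — an illustration of the technique, not the MLX implementation in this commit — might look like:

```python
import math
import numpy as np

def scaled_dot_product_attention(queries, keys, values):
    # Same scaling as in the diff: sqrt(1 / d_k) == 1 / sqrt(d_k)
    scale = math.sqrt(1 / queries.shape[-1])
    # Scaled similarity scores between queries and keys
    scores = (queries * scale) @ keys.swapaxes(-1, -2)
    # Numerically stable softmax over the key axis
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # Attention output: weighted sum of values
    return weights @ values

# Example shapes: (batch, seq_len, head_dim)
q = np.random.randn(1, 4, 8)
k = np.random.randn(1, 4, 8)
v = np.random.randn(1, 4, 8)
out = scaled_dot_product_attention(q, k, v)
```

The commit's speedup comes from fusing these steps into a single kernel rather than materializing the intermediate `scores` and `weights` arrays as this sketch does.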