Use fast rope (#945)

* use fast rope

* fix llama

* use fast rope for llama3.1

* requires unreleased mlx

* fix su

* fix deepseek v2

* only one of base or freqs

* nit

* fix

* hard code freqs
This commit is contained in:
Awni Hannun
2024-08-23 13:18:51 -07:00
committed by GitHub
parent 58591a1b41
commit 6731254e76
7 changed files with 65 additions and 137 deletions

View File

@@ -59,19 +59,17 @@ class Attention(nn.Module):
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding(
head_dim,
traditional=False,
base=args.rope_theta,
scale=rope_scale,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
)
else:
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]