mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
longrope (#886)
This commit is contained in:
parent
8bf397e450
commit
bfc1f2763b
@ -33,9 +33,9 @@ class ModelArgs(BaseModelArgs):
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] not in ["su", "linear"]:
|
||||
if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
|
||||
print(
|
||||
"[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
|
||||
"[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
|
||||
)
|
||||
self.rope_scaling = None
|
||||
|
||||
@ -58,7 +58,7 @@ class Attention(nn.Module):
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = 1.0
|
||||
if args.rope_scaling and args.rope_scaling["type"] == "su":
|
||||
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
|
||||
self.rope = SuScaledRotaryEmbedding(
|
||||
head_dim,
|
||||
traditional=False,
|
||||
|
Loading…
Reference in New Issue
Block a user