This commit is contained in:
JosefAlbers 2024-07-12 23:19:11 +09:00 committed by GitHub
parent 8bf397e450
commit bfc1f2763b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -33,9 +33,9 @@ class ModelArgs(BaseModelArgs):
if not all(key in self.rope_scaling for key in required_keys): if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}") raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] not in ["su", "linear"]: if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
print( print(
"[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false." "[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
) )
self.rope_scaling = None self.rope_scaling = None
@@ -58,7 +58,7 @@ class Attention(nn.Module):
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = 1.0 rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "su": if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding( self.rope = SuScaledRotaryEmbedding(
head_dim, head_dim,
traditional=False, traditional=False,