chore: remove hardcoded rope_scaling_factor

This commit is contained in:
Anchen 2023-12-22 18:03:20 +11:00 committed by Awni Hannun
parent e17e07002a
commit 6a62a8bca4

View File

@@ -105,7 +105,9 @@ class Attention(nn.Module):
args.num_attention_heads * self.head_dim, args.hidden_size, bias=False
)
self.rope = LinearScalingRoPE(
-            self.head_dim, rope_scaling_factor=4.0, base=args.rope_theta
+            self.head_dim,
+            rope_scaling_factor=args.rope_scaling_factor,
+            base=args.rope_theta,
)
def __call__(