chore: remove the default rope_scaling dict and use .get() to access "type" and "factor" to avoid a KeyError

This commit is contained in:
anchen
2024-01-05 23:39:49 -08:00
parent c270881a85
commit fa6ff4e517

View File

@@ -3,7 +3,7 @@
import glob
import inspect
import json
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
@@ -26,20 +26,19 @@ class ModelArgs:
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = field(default_factory=dict)
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
# Post-init hook: fills in a default for num_key_value_heads and validates
# rope_scaling. Presumably invoked by the dataclass-generated __init__
# (ModelArgs' decorator is not visible here -- confirm).
# Raises ValueError when rope_scaling is non-empty but lacks a required key
# or names a scaling type other than "linear".
# When no key/value head count was given, reuse the attention head count.
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
# Truthiness check: both None and an empty dict skip validation entirely.
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
# Only the presence of "factor" is checked above; its value is not
# validated here. NOTE(review): a zero or non-numeric factor would pass.
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@classmethod
def from_dict(cls, params):
@@ -83,9 +82,17 @@ class Attention(nn.Module):
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = 1 / args.rope_scaling['factor'] if args.rope_scaling is not None and args.rope_scaling['type'] == 'linear' else 1
rope_scale = (
1 / args.rope_scaling.get("factor", 1.0)
if args.rope_scaling is not None
and args.rope_scaling.get("type") == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim, traditional=args.rope_traditional, base=args.rope_theta, scale= rope_scale
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(