Add support for InternLM-2.5 (#871)

* fix internlm-2

* formatting

* add dynamic ntk rope

* formatting

* move dynamic scaling rope to internlm2.py

* add default max_position_embeddings
Prince Canuma 2024-07-18 01:38:22 +02:00 committed by GitHub
parent 561dcf5643
commit 3d365b612a


@@ -17,6 +17,7 @@ class ModelArgs(BaseModelArgs):
     rms_norm_eps: float
     vocab_size: int
     bias: bool = True
+    max_position_embeddings: int = 32768
     num_key_value_heads: int = None
     rope_theta: float = 10000
     rope_traditional: bool = False
@@ -32,8 +33,50 @@ class ModelArgs(BaseModelArgs):
             if not all(key in self.rope_scaling for key in required_keys):
                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+            if self.rope_scaling["type"] not in ["linear", "dynamic"]:
+                raise ValueError(
+                    "rope_scaling 'type' currently only supports 'linear' or 'dynamic'"
+                )
+
+
+class DynamicNTKScalingRoPE(nn.Module):
+    """Implements the rotary positional encoding with Dynamic NTK scaling."""
+
+    def __init__(
+        self,
+        dims: int,
+        max_position_embeddings: int = 2048,
+        traditional: bool = False,
+        base: float = 10000,
+        scale: float = 1.0,
+    ):
+        super().__init__()
+        self.max_position_embeddings = max_position_embeddings
+        self.original_base = base
+        self.dims = dims
+        self.traditional = traditional
+        self.scale = scale
+
+    def extra_repr(self):
+        return (
+            f"{self.dims}, traditional={self.traditional}, "
+            f"max_position_embeddings={self.max_position_embeddings}, "
+            f"scale={self.scale}"
+        )
+
+    def __call__(self, x, offset: int = 0):
+        seq_len = x.shape[1] + offset
+        # Past the trained context length, grow the base so the rotary
+        # frequencies stretch to cover the longer sequence (Dynamic NTK).
+        if seq_len > self.max_position_embeddings:
+            base = self.original_base * (
+                (self.scale * seq_len / self.max_position_embeddings)
+                - (self.scale - 1)
+            ) ** (self.dims / (self.dims - 2))
+        else:
+            base = self.original_base
+
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=base,
+            scale=self.scale,
+            offset=offset,
+        )
 
 
 class Attention(nn.Module):
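
The dynamic branch above only kicks in once the running length (tokens processed so far plus the current offset) exceeds max_position_embeddings; below that it behaves like plain RoPE. A minimal sketch of the base adjustment, assuming dims=128, base=10000, scale=2.0, and a 32768-token trained context (values chosen purely for illustration):

def ntk_adjusted_base(
    seq_len: int,
    dims: int = 128,
    base: float = 10000.0,
    scale: float = 2.0,
    max_position_embeddings: int = 32768,
) -> float:
    # Same arithmetic as DynamicNTKScalingRoPE.__call__: keep the base
    # unchanged inside the trained context, otherwise stretch it.
    if seq_len <= max_position_embeddings:
        return base
    return base * (
        (scale * seq_len / max_position_embeddings) - (scale - 1)
    ) ** (dims / (dims - 2))


print(ntk_adjusted_base(16_384))  # 10000.0 (within the trained context)
print(ntk_adjusted_base(65_536))  # ~30527.0 (base roughly triples)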
@@ -56,10 +99,12 @@ class Attention(nn.Module):
         rope_scale = (
             1 / args.rope_scaling["factor"]
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
+            else 2.0
         )
-        self.rope = nn.RoPE(
+        self.rope = DynamicNTKScalingRoPE(
             head_dim,
+            max_position_embeddings=args.max_position_embeddings,
             traditional=args.rope_traditional,
             base=args.rope_theta,
             scale=rope_scale,
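
As wired up here, a config that declares linear rope scaling gets scale = 1 / factor, while a dynamic (or absent) rope_scaling entry falls through to the fixed 2.0. A small sketch of that selection logic, with example config dicts assumed for illustration:

from typing import Optional


def pick_rope_scale(rope_scaling: Optional[dict]) -> float:
    # Mirrors the rope_scale expression in Attention.__init__: only linear
    # scaling inverts the factor; every other case uses the fixed 2.0.
    if rope_scaling is not None and rope_scaling["type"] == "linear":
        return 1 / rope_scaling["factor"]
    return 2.0


print(pick_rope_scale({"type": "linear", "factor": 4.0}))   # 0.25
print(pick_rope_scale({"type": "dynamic", "factor": 2.0}))  # 2.0
print(pick_rope_scale(None))                                # 2.0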
@@ -185,6 +230,10 @@ class Model(nn.Module):
         out = self.output(out)
         return out
 
+    def sanitize(self, weights):
+        # Remove unused precomputed rotary freqs
+        return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
+
     @property
     def layers(self):
         return self.model.layers
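
With these changes, an InternLM-2.5 checkpoint converted to MLX should load through the usual mlx_lm entry points, with sanitize dropping the precomputed attention.rope.inv_freq buffers since mx.fast.rope derives the frequencies on the fly. A minimal usage sketch; the repo name below is an assumption, so substitute whatever converted InternLM-2.5 weights you actually have:

from mlx_lm import load, generate

# Hypothetical converted checkpoint name; any InternLM-2.5 conversion works.
model, tokenizer = load("mlx-community/internlm2_5-7b-chat-4bit")

text = generate(
    model,
    tokenizer,
    prompt="Summarize what Dynamic NTK RoPE scaling does.",
    max_tokens=128,
    verbose=True,
)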