mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add support for InternLM-2.5 (#871)
* fix internlm-2 * formatting * add dynamic ntk rope * formatting * move dynamic scaling rope to intermlm2.py * add default max_position_embeddings
This commit is contained in:
parent
561dcf5643
commit
3d365b612a
@ -17,6 +17,7 @@ class ModelArgs(BaseModelArgs):
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
bias: bool = True
|
||||
max_position_embeddings: int = 32768
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
@ -32,8 +33,50 @@ class ModelArgs(BaseModelArgs):
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
if self.rope_scaling["type"] not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
"rope_scaling 'type' currently only supports 'linear' or 'dynamic"
|
||||
)
|
||||
|
||||
|
||||
class DynamicNTKScalingRoPE(nn.Module):
|
||||
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_base = base
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
self.scale = scale
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
seq_len = x.shape[1] + offset
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.original_base * (
|
||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||
) ** (self.dims / (self.dims - 2))
|
||||
else:
|
||||
base = self.original_base
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@ -56,10 +99,12 @@ class Attention(nn.Module):
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
else 2.0
|
||||
)
|
||||
self.rope = nn.RoPE(
|
||||
|
||||
self.rope = DynamicNTKScalingRoPE(
|
||||
head_dim,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
@ -185,6 +230,10 @@ class Model(nn.Module):
|
||||
out = self.output(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
Loading…
Reference in New Issue
Block a user