Add support for InternLM-2.5 (#871)
* fix internlm-2
* formatting
* add dynamic ntk rope
* formatting
* move dynamic scaling rope to internlm2.py
* add default max_position_embeddings
parent 561dcf5643
commit 3d365b612a
@@ -17,6 +17,7 @@ class ModelArgs(BaseModelArgs):
     rms_norm_eps: float
     vocab_size: int
     bias: bool = True
+    max_position_embeddings: int = 32768
     num_key_value_heads: int = None
     rope_theta: float = 10000
     rope_traditional: bool = False
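The new field gives ModelArgs a default context length of 32768 tokens that a checkpoint's config.json can override. A minimal sketch of that default/override behaviour, using a stripped-down stand-in dataclass rather than the real ModelArgs:

    from dataclasses import dataclass

    # Stand-in with only the field added above; the real ModelArgs carries many more fields.
    @dataclass
    class ArgsSketch:
        max_position_embeddings: int = 32768

    print(ArgsSketch().max_position_embeddings)  # 32768 when config.json omits the key
    print(ArgsSketch(max_position_embeddings=262144).max_position_embeddings)  # overridden value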
@@ -32,8 +33,50 @@ class ModelArgs(BaseModelArgs):
             if not all(key in self.rope_scaling for key in required_keys):
                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+            if self.rope_scaling["type"] not in ["linear", "dynamic"]:
+                raise ValueError(
+                    "rope_scaling 'type' currently only supports 'linear' or 'dynamic'"
+                )
+
+
+class DynamicNTKScalingRoPE(nn.Module):
+    """Implements the rotary positional encoding with Dynamic NTK scaling."""
+
+    def __init__(
+        self,
+        dims: int,
+        max_position_embeddings: int = 2048,
+        traditional: bool = False,
+        base: float = 10000,
+        scale: float = 1.0,
+    ):
+        super().__init__()
+        self.max_position_embeddings = max_position_embeddings
+        self.original_base = base
+        self.dims = dims
+        self.traditional = traditional
+        self.scale = scale
+
+    def extra_repr(self):
+        return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scale}"
+
+    def __call__(self, x, offset: int = 0):
+        seq_len = x.shape[1] + offset
+        if seq_len > self.max_position_embeddings:
+            base = self.original_base * (
+                (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
+            ) ** (self.dims / (self.dims - 2))
+        else:
+            base = self.original_base
+
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=base,
+            scale=self.scale,
+            offset=offset,
+        )
+
+
 class Attention(nn.Module):
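When the current sequence length (input plus cache offset) exceeds max_position_embeddings, the module inflates the RoPE base by the NTK-aware factor before handing it to mx.fast.rope, which stretches the rotary wavelengths instead of rescaling positions. A minimal sketch of that computation; the dims, base, and context lengths below are placeholder values for illustration, not InternLM-2.5's actual configuration:

    import mlx.core as mx

    dims, base, scale, max_pos = 128, 10000.0, 1.0, 2048
    seq_len = 4096  # twice the assumed trained context

    # Same rescaling as DynamicNTKScalingRoPE.__call__ above
    ntk_base = base * ((scale * seq_len / max_pos) - (scale - 1)) ** (dims / (dims - 2))
    print(ntk_base)  # ~20222 = 10000 * 2 ** (128 / 126): longer rotary wavelengths

    # The adjusted base is passed straight to mx.fast.rope
    x = mx.zeros((1, seq_len, dims))  # toy (batch, length, head_dim) input
    y = mx.fast.rope(x, dims, traditional=False, base=ntk_base, scale=scale, offset=0)
    print(y.shape)  # (1, 4096, 128)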
@@ -56,10 +99,12 @@ class Attention(nn.Module):
         rope_scale = (
             1 / args.rope_scaling["factor"]
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
+            else 2.0
         )
-        self.rope = nn.RoPE(
+
+        self.rope = DynamicNTKScalingRoPE(
             head_dim,
+            max_position_embeddings=args.max_position_embeddings,
             traditional=args.rope_traditional,
             base=args.rope_theta,
             scale=rope_scale,
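With this change the 1/factor scale only applies to "linear" rope_scaling; every other case, including "dynamic" and no rope_scaling entry at all, falls back to 2.0 and is passed to DynamicNTKScalingRoPE. A small sketch of how the conditional resolves for a few illustrative configs (these dicts are made up for the example, not taken from a real checkpoint):

    def resolve_rope_scale(rope_scaling):
        # Mirrors the expression in Attention.__init__ above
        return (
            1 / rope_scaling["factor"]
            if rope_scaling is not None and rope_scaling["type"] == "linear"
            else 2.0
        )

    print(resolve_rope_scale({"type": "linear", "factor": 4.0}))   # 0.25
    print(resolve_rope_scale({"type": "dynamic", "factor": 2.0}))  # 2.0
    print(resolve_rope_scale(None))                                # 2.0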
@@ -185,6 +230,10 @@ class Model(nn.Module):
         out = self.output(out)
         return out
 
+    def sanitize(self, weights):
+        # Remove unused precomputed rotary freqs
+        return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
+
     @property
     def layers(self):
         return self.model.layers
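The new sanitize() drops the precomputed attention.rope.inv_freq buffers that the checkpoints ship but that this implementation does not use. A quick sketch with hypothetical key names, only to show which entries the filter removes:

    # Hypothetical key names for illustration
    weights = {
        "model.layers.0.attention.wqkv.weight": "...",
        "model.layers.0.attention.rope.inv_freq": "...",  # dropped: matches the filter
    }
    kept = {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
    print(list(kept))  # ['model.layers.0.attention.wqkv.weight']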