diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py
index 780d8ae6..fc1cfc03 100644
--- a/llms/mlx_lm/models/internlm2.py
+++ b/llms/mlx_lm/models/internlm2.py
@@ -17,6 +17,7 @@ class ModelArgs(BaseModelArgs):
     rms_norm_eps: float
     vocab_size: int
     bias: bool = True
+    max_position_embeddings: int = 32768
     num_key_value_heads: int = None
     rope_theta: float = 10000
     rope_traditional: bool = False
@@ -32,8 +33,50 @@ class ModelArgs(BaseModelArgs):
             if not all(key in self.rope_scaling for key in required_keys):
                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+            if self.rope_scaling["type"] not in ["linear", "dynamic"]:
+                raise ValueError(
+                    "rope_scaling 'type' currently only supports 'linear' or 'dynamic'"
+                )
+
+
+class DynamicNTKScalingRoPE(nn.Module):
+    """Implements the rotary positional encoding with Dynamic NTK scaling."""
+
+    def __init__(
+        self,
+        dims: int,
+        max_position_embeddings: int = 2048,
+        traditional: bool = False,
+        base: float = 10000,
+        scale: float = 1.0,
+    ):
+        super().__init__()
+        self.max_position_embeddings = max_position_embeddings
+        self.original_base = base
+        self.dims = dims
+        self.traditional = traditional
+        self.scale = scale
+
+    def extra_repr(self):
+        return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scale}"
+
+    def __call__(self, x, offset: int = 0):
+        seq_len = x.shape[1] + offset
+        if seq_len > self.max_position_embeddings:
+            base = self.original_base * (
+                (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
+            ) ** (self.dims / (self.dims - 2))
+        else:
+            base = self.original_base
+
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=base,
+            scale=self.scale,
+            offset=offset,
+        )
 
 
 class Attention(nn.Module):
@@ -56,10 +99,12 @@ class Attention(nn.Module):
         rope_scale = (
             1 / args.rope_scaling["factor"]
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
+            else 2.0
         )
-        self.rope = nn.RoPE(
+
+        self.rope = DynamicNTKScalingRoPE(
             head_dim,
+            max_position_embeddings=args.max_position_embeddings,
             traditional=args.rope_traditional,
             base=args.rope_theta,
             scale=rope_scale,
@@ -185,6 +230,10 @@ class Model(nn.Module):
         out = self.output(out)
         return out
 
+    def sanitize(self, weights):
+        # Remove unused precomputed rotary freqs
+        return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
+
     @property
     def layers(self):
         return self.model.layers
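# Minimal sketch (not part of the patch): how the base adjustment in
# DynamicNTKScalingRoPE.__call__ behaves as the sequence grows past the
# trained context window. The concrete numbers below (dims=128, base=10000,
# max_position_embeddings=32768, scale=2.0) are illustrative assumptions,
# not values taken from any particular InternLM2 config.

def dynamic_ntk_base(
    seq_len: int,
    dims: int = 128,
    base: float = 10000.0,
    max_position_embeddings: int = 32768,
    scale: float = 2.0,
) -> float:
    """Return the RoPE base used for a sequence of length seq_len."""
    if seq_len <= max_position_embeddings:
        # Within the trained window the original base is used unchanged.
        return base
    # Beyond the window, grow the base, mirroring the expression in __call__:
    # base * ((scale * seq_len / max_pos) - (scale - 1)) ** (dims / (dims - 2))
    return base * (
        (scale * seq_len / max_position_embeddings) - (scale - 1)
    ) ** (dims / (dims - 2))


if __name__ == "__main__":
    # The base stays at 10000 up to 32768 tokens, then increases with length.
    for n in (8192, 32768, 65536, 131072):
        print(n, round(dynamic_ntk_base(n), 1))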