From 00a73790702991075b80f9facf219ae397e1eb15 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 26 Feb 2025 16:21:54 -0800
Subject: [PATCH] Fixes for phi4 mini (#1305)

---
 llms/mlx_lm/models/phi3.py    | 16 ++++++++++++----
 llms/mlx_lm/models/su_rope.py |  6 ++++--
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index d1c21e25..63e985de 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
+    partial_rotary_factor: float = 1.0
     max_position_embeddings: int = 131072
     original_max_position_embeddings: int = 4096
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -59,9 +61,10 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
+        rope_dim = int(head_dim * args.partial_rotary_factor)
         if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
             self.rope = SuScaledRotaryEmbedding(
-                head_dim,
+                rope_dim,
                 base=args.rope_theta,
                 max_position_embeddings=args.max_position_embeddings,
                 original_max_position_embeddings=args.original_max_position_embeddings,
@@ -74,7 +77,7 @@ class Attention(nn.Module):
                 assert isinstance(args.rope_scaling["factor"], float)
                 rope_scale = 1 / args.rope_scaling["factor"]
             self.rope = nn.RoPE(
-                head_dim,
+                rope_dim,
                 traditional=args.rope_traditional,
                 base=args.rope_theta,
                 scale=rope_scale,
@@ -190,7 +193,8 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Phi3Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
         self.args = args
 
     def __call__(
@@ -200,7 +204,11 @@ def __call__(
         cache=None,
     ):
         out = self.model(inputs, mask, cache)
-        return self.lm_head(out)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
 
     @property
     def layers(self):
diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py
index 9c414afd..6340c77b 100644
--- a/llms/mlx_lm/models/su_rope.py
+++ b/llms/mlx_lm/models/su_rope.py
@@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
             + math.log(max_position_embeddings / original_max_position_embeddings)
             / math.log(original_max_position_embeddings)
         )
+        self.dim = dims
 
     def __call__(self, x, offset: int = 0):
+        x[..., : self.dim] = self.scale * x[..., : self.dim]
         return mx.fast.rope(
-            self.scale * x,
-            x.shape[-1],
+            x,
+            self.dim,
             traditional=False,
             base=None,
             scale=1.0,
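
Note on the partial rotary change: with a partial_rotary_factor below 1.0, only
the leading rope_dim features of each attention head receive rotary position
embeddings and the remaining features pass through untouched, which is why the
patch swaps head_dim for rope_dim in both rope constructors. A minimal sketch of
that pass-through behavior using MLX's stock nn.RoPE; the sizes and the 0.75
factor below are illustrative values, not read from any model config:

    import mlx.core as mx
    import mlx.nn as nn

    head_dim = 128
    partial_rotary_factor = 0.75  # illustrative value
    rope_dim = int(head_dim * partial_rotary_factor)  # 96

    rope = nn.RoPE(rope_dim, traditional=False, base=10000)

    x = mx.random.normal((1, 8, 16, head_dim))  # (batch, heads, seq, head_dim)
    y = rope(x)

    # Only the first rope_dim features are rotated; the tail is left unchanged.
    print(mx.allclose(x[..., rope_dim:], y[..., rope_dim:]))  # True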
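
Note on tie_word_embeddings: when the checkpoint ships no separate lm_head
weight, the input embedding matrix doubles as the output projection via
nn.Embedding.as_linear, which is what the Model.__call__ change does. A toy
sketch of the equivalence, with made-up sizes:

    import mlx.core as mx
    import mlx.nn as nn

    vocab_size, hidden_size = 32, 8  # toy sizes
    embed = nn.Embedding(vocab_size, hidden_size)

    tokens = mx.array([[1, 2, 3]])
    h = embed(tokens)           # (1, 3, hidden_size) embedding lookup

    logits = embed.as_linear(h) # (1, 3, vocab_size), reuses embed.weight
    print(mx.allclose(logits, h @ embed.weight.T))  # True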
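
Note on the su_rope change: previously the long-context attention scale was
applied to the whole activation (self.scale * x) and the rope span was
x.shape[-1]; with partial rotary, both the scaling and the rotation must be
restricted to the first self.dim features. A minimal sketch of the fixed
behavior; the sizes, scale, and frequencies are stand-ins for the real
SU-scaled values computed in SuScaledRotaryEmbedding:

    import mlx.core as mx

    dims, head_dim = 64, 96
    scale = 1.19                                  # stand-in attention scale
    freqs = mx.linspace(1.0, 10000.0, dims // 2)  # stand-in frequencies

    x = mx.random.normal((1, 4, 8, head_dim))
    x[..., :dims] = scale * x[..., :dims]  # scale only the rotated slice
    y = mx.fast.rope(
        x, dims, traditional=False, base=None, scale=1.0, offset=0, freqs=freqs
    )

    # Features beyond dims are neither scaled nor rotated.
    print(mx.allclose(y[..., dims:], x[..., dims:]))  # True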