Fixes for phi4 mini (#1305)

Author: Awni Hannun
Date: 2025-02-26 16:21:54 -08:00
Committed by: GitHub
parent 0f240a4c7e
commit 00a7379070
2 changed files with 16 additions and 6 deletions

llms/mlx_lm/models/phi3.py

@@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
+    partial_rotary_factor: float = 1.0
     max_position_embeddings: int = 131072
     original_max_position_embeddings: int = 4096
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
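
Note: the two new ModelArgs fields mirror the matching Hugging Face config keys. A minimal sketch of the relevant config entries (values illustrative, not copied from a released checkpoint):

# Illustrative excerpt of a Phi-4-mini style config; only the keys
# touched by this change are shown, with made-up values.
config = {
    "partial_rotary_factor": 0.75,  # rotate 75% of each head's dims
    "tie_word_embeddings": True,    # reuse embed_tokens as the LM head
}
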
@@ -59,9 +61,10 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
+        rope_dim = int(head_dim * args.partial_rotary_factor)
         if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
             self.rope = SuScaledRotaryEmbedding(
-                head_dim,
+                rope_dim,
                 base=args.rope_theta,
                 max_position_embeddings=args.max_position_embeddings,
                 original_max_position_embeddings=args.original_max_position_embeddings,
@@ -74,7 +77,7 @@ class Attention(nn.Module):
             assert isinstance(args.rope_scaling["factor"], float)
             rope_scale = 1 / args.rope_scaling["factor"]
             self.rope = nn.RoPE(
-                head_dim,
+                rope_dim,
                 traditional=args.rope_traditional,
                 base=args.rope_theta,
                 scale=rope_scale,
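
Note: partial_rotary_factor controls how many of each head's dimensions receive rotary position embeddings; the rest pass through unrotated. A worked example of the rope_dim computation above (values illustrative):

# With a head dimension of 128 and a partial rotary factor of 0.75,
# only the first 96 dimensions of each head are rotated.
head_dim = 128
partial_rotary_factor = 0.75
rope_dim = int(head_dim * partial_rotary_factor)  # 96
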
@@ -190,7 +193,8 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Phi3Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
         self.args = args
 
     def __call__(
@@ -200,7 +204,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out = self.model(inputs, mask, cache)
-        return self.lm_head(out)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
 
     @property
     def layers(self):
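
Note: with tied word embeddings, the input embedding matrix doubles as the output projection, which is what nn.Embedding.as_linear provides in MLX. A minimal standalone sketch of the equivalence (shapes illustrative):

import mlx.core as mx
import mlx.nn as nn

embed = nn.Embedding(100, 16)       # vocab of 100, hidden size 16
h = mx.random.normal((1, 4, 16))    # stand-in for the model's hidden states

logits_tied = embed.as_linear(h)    # embedding weights reused as the LM head
logits_manual = h @ embed.weight.T  # the same projection, written out
assert mx.allclose(logits_tied, logits_manual)
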

llms/mlx_lm/models/su_rope.py

@@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
             + math.log(max_position_embeddings / original_max_position_embeddings)
             / math.log(original_max_position_embeddings)
         )
+        self.dim = dims
 
     def __call__(self, x, offset: int = 0):
+        x[..., : self.dim] = self.scale * x[..., : self.dim]
         return mx.fast.rope(
-            self.scale * x,
-            x.shape[-1],
+            x,
+            self.dim,
             traditional=False,
             base=None,
             scale=1.0,
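
Note: mx.fast.rope rotates only the first dims features of the input and passes the remainder through, which is why the long-context scale is now applied just to the rotated slice instead of to all of x. A minimal sketch of that behavior (shapes and values illustrative):

import mlx.core as mx

B, H, L, D = 1, 2, 8, 64
dim = 48                            # rotary dims, e.g. 0.75 * 64
x = mx.random.normal((B, H, L, D))

y = mx.fast.rope(x, dim, traditional=False, base=10000.0, scale=1.0, offset=0)

# Features past `dim` are left untouched by the rotation.
assert mx.allclose(y[..., dim:], x[..., dim:])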