Fixes for phi4 mini (#1305)

This commit is contained in:
Awni Hannun 2025-02-26 16:21:54 -08:00 committed by GitHub
parent 0f240a4c7e
commit 00a7379070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 6 deletions

View File

@@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs):
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
partial_rotary_factor: float = 1.0
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
tie_word_embeddings: bool = False
def __post_init__(self):
if self.num_key_value_heads is None:
@@ -59,9 +61,10 @@ class Attention(nn.Module):
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_dim = int(head_dim * args.partial_rotary_factor)
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding(
head_dim,
rope_dim,
base=args.rope_theta,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
@@ -74,7 +77,7 @@ class Attention(nn.Module):
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
head_dim,
rope_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
@@ -190,7 +193,8 @@ class Model(nn.Module):
super().__init__()
self.model_type = args.model_type
self.model = Phi3Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args
def __call__(
@@ -200,7 +204,11 @@ class Model(nn.Module):
cache=None,
):
out = self.model(inputs, mask, cache)
return self.lm_head(out)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):

View File

@@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
+ math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings)
)
self.dim = dims
def __call__(self, x, offset: int = 0):
x[..., : self.dim] = self.scale * x[..., : self.dim]
return mx.fast.rope(
self.scale * x,
x.shape[-1],
x,
self.dim,
traditional=False,
base=None,
scale=1.0,