Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00

Fixes for phi4 mini (#1305)

commit 00a7379070 (parent 0f240a4c7e)

@@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
+    partial_rotary_factor: float = 1.0
     max_position_embeddings: int = 131072
     original_max_position_embeddings: int = 4096
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
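
These two fields correspond to entries in Phi-4-mini's published config: partial_rotary_factor gives the fraction of each attention head that receives rotary position embeddings, and tie_word_embeddings marks checkpoints whose output projection reuses the input embedding matrix. A quick sketch of the first (values are illustrative, not read from a real config):

head_dim = 128                                     # hidden_size // num_attention_heads
partial_rotary_factor = 0.75                       # fraction of head_dim to rotate
rope_dim = int(head_dim * partial_rotary_factor)   # 96 rotary dims, 32 pass-through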

@@ -59,9 +61,10 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
+        rope_dim = int(head_dim * args.partial_rotary_factor)
         if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
             self.rope = SuScaledRotaryEmbedding(
-                head_dim,
+                rope_dim,
                 base=args.rope_theta,
                 max_position_embeddings=args.max_position_embeddings,
                 original_max_position_embeddings=args.original_max_position_embeddings,

@@ -74,7 +77,7 @@ class Attention(nn.Module):
             assert isinstance(args.rope_scaling["factor"], float)
             rope_scale = 1 / args.rope_scaling["factor"]
             self.rope = nn.RoPE(
-                head_dim,
+                rope_dim,
                 traditional=args.rope_traditional,
                 base=args.rope_theta,
                 scale=rope_scale,
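
For context, a minimal runnable sketch (not the repo's code) of why rope_dim replaces head_dim in both branches: MLX's nn.RoPE rotates only the first dims features of each head and passes the rest through unchanged, which is exactly the partial rotary behavior phi4 mini needs. Sizes below are illustrative.

import mlx.core as mx
import mlx.nn as nn

head_dim, rope_dim = 128, 96                # illustrative phi4-mini-like sizes
rope = nn.RoPE(rope_dim, traditional=False, base=10000)

x = mx.random.normal((1, 4, 8, head_dim))   # (batch, heads, seq, head_dim)
y = rope(x)
# Only the first rope_dim features were rotated; the tail is untouched.
print(mx.allclose(x[..., rope_dim:], y[..., rope_dim:]))  # True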

@@ -190,7 +193,8 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Phi3Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
         self.args = args
 
     def __call__(
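
The construction is conditional because a tied-weights checkpoint ships no lm_head.weight; instantiating the Linear anyway would leave an unmatched parameter at load time. A standalone sketch of the pattern, using a hypothetical Head module:

import mlx.nn as nn

class Head(nn.Module):
    def __init__(self, tie_word_embeddings: bool):
        super().__init__()
        if not tie_word_embeddings:
            # Created only when the checkpoint actually contains these weights.
            self.lm_head = nn.Linear(16, 100, bias=False)

print("lm_head" in Head(True).parameters())   # False: nothing to match at load
print("lm_head" in Head(False).parameters())  # True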

@@ -200,7 +204,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out = self.model(inputs, mask, cache)
-        return self.lm_head(out)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
 
     @property
     def layers(self):
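
A sketch of what embed_tokens.as_linear does here (illustrative sizes, assuming MLX's nn.Embedding): the embedding matrix doubles as the output projection, so logits come from a matmul against its transpose.

import mlx.core as mx
import mlx.nn as nn

vocab, dim = 100, 16                          # illustrative sizes
emb = nn.Embedding(vocab, dim)

h = mx.random.normal((1, 4, dim))             # decoder hidden states
logits = emb.as_linear(h)                     # reuse embedding weights as lm head
print(mx.allclose(logits, h @ emb.weight.T))  # True: same as a tied Linear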

@@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
             + math.log(max_position_embeddings / original_max_position_embeddings)
             / math.log(original_max_position_embeddings)
         )
+        self.dim = dims
 
     def __call__(self, x, offset: int = 0):
+        x[..., : self.dim] = self.scale * x[..., : self.dim]
         return mx.fast.rope(
-            self.scale * x,
-            x.shape[-1],
+            x,
+            self.dim,
             traditional=False,
             base=None,
             scale=1.0,
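
The su_rope change is the subtle one: the long-context scale is applied only to the rotary slice x[..., : self.dim], and mx.fast.rope rotates self.dim features instead of all of x.shape[-1]. A self-contained sketch under assumed values (the freqs below are placeholders, not the module's real longrope frequencies):

import math
import mlx.core as mx

dim, head_dim = 96, 128                  # illustrative partial-rotary sizes
scale = math.sqrt(1 + math.log(131072 / 4096) / math.log(4096))
freqs = mx.arange(1.0, dim // 2 + 1)     # placeholder per-dimension frequencies

x = mx.random.normal((1, 4, 8, head_dim))
x[..., :dim] = scale * x[..., :dim]      # scale only the features that get rotated
y = mx.fast.rope(x, dim, traditional=False, base=None, scale=1.0, offset=0, freqs=freqs)
# The non-rotary tail is neither scaled nor rotated.
print(mx.allclose(x[..., dim:], y[..., dim:]))  # True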