mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fixes for phi4 mini (#1305)
This commit is contained in:
parent
0f240a4c7e
commit
00a7379070
@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs):
|
|||||||
rope_theta: float = 10000
|
rope_theta: float = 10000
|
||||||
rope_traditional: bool = False
|
rope_traditional: bool = False
|
||||||
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
|
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
|
||||||
|
partial_rotary_factor: float = 1.0
|
||||||
max_position_embeddings: int = 131072
|
max_position_embeddings: int = 131072
|
||||||
original_max_position_embeddings: int = 4096
|
original_max_position_embeddings: int = 4096
|
||||||
|
tie_word_embeddings: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.num_key_value_heads is None:
|
if self.num_key_value_heads is None:
|
||||||
@ -59,9 +61,10 @@ class Attention(nn.Module):
|
|||||||
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
|
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
|
||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||||
|
|
||||||
|
rope_dim = int(head_dim * args.partial_rotary_factor)
|
||||||
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
|
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
|
||||||
self.rope = SuScaledRotaryEmbedding(
|
self.rope = SuScaledRotaryEmbedding(
|
||||||
head_dim,
|
rope_dim,
|
||||||
base=args.rope_theta,
|
base=args.rope_theta,
|
||||||
max_position_embeddings=args.max_position_embeddings,
|
max_position_embeddings=args.max_position_embeddings,
|
||||||
original_max_position_embeddings=args.original_max_position_embeddings,
|
original_max_position_embeddings=args.original_max_position_embeddings,
|
||||||
@ -74,7 +77,7 @@ class Attention(nn.Module):
|
|||||||
assert isinstance(args.rope_scaling["factor"], float)
|
assert isinstance(args.rope_scaling["factor"], float)
|
||||||
rope_scale = 1 / args.rope_scaling["factor"]
|
rope_scale = 1 / args.rope_scaling["factor"]
|
||||||
self.rope = nn.RoPE(
|
self.rope = nn.RoPE(
|
||||||
head_dim,
|
rope_dim,
|
||||||
traditional=args.rope_traditional,
|
traditional=args.rope_traditional,
|
||||||
base=args.rope_theta,
|
base=args.rope_theta,
|
||||||
scale=rope_scale,
|
scale=rope_scale,
|
||||||
@ -190,7 +193,8 @@ class Model(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_type = args.model_type
|
self.model_type = args.model_type
|
||||||
self.model = Phi3Model(args)
|
self.model = Phi3Model(args)
|
||||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
if not args.tie_word_embeddings:
|
||||||
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@ -200,7 +204,11 @@ class Model(nn.Module):
|
|||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, mask, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
if self.args.tie_word_embeddings:
|
||||||
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
|
else:
|
||||||
|
out = self.lm_head(out)
|
||||||
|
return out
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
|
@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
|
|||||||
+ math.log(max_position_embeddings / original_max_position_embeddings)
|
+ math.log(max_position_embeddings / original_max_position_embeddings)
|
||||||
/ math.log(original_max_position_embeddings)
|
/ math.log(original_max_position_embeddings)
|
||||||
)
|
)
|
||||||
|
self.dim = dims
|
||||||
|
|
||||||
def __call__(self, x, offset: int = 0):
|
def __call__(self, x, offset: int = 0):
|
||||||
|
x[..., : self.dim] = self.scale * x[..., : self.dim]
|
||||||
return mx.fast.rope(
|
return mx.fast.rope(
|
||||||
self.scale * x,
|
x,
|
||||||
x.shape[-1],
|
self.dim,
|
||||||
traditional=False,
|
traditional=False,
|
||||||
base=None,
|
base=None,
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
|
Loading…
Reference in New Issue
Block a user