Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-07-19 17:41:11 +08:00)
chore: enable tie_word_embeddings config for qwen2 (#544)
This commit is contained in:
parent: b8e5eda4fd
commit: 8a178f8716
@@ -21,6 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -176,6 +177,7 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
+        if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
@@ -184,8 +186,12 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
+        if hasattr(self, "lm_head"):
             return self.lm_head(out), cache
 
+        out = out @ self.model.embed_tokens.weight.T
+        return out, cache
+
     @staticmethod
     def sanitize(weights):
         # Remove unused precomputed rotary freqs
Loading…
Reference in New Issue
Block a user