From 8a178f87163ecaa74aace9c49616576364866087 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 8 Mar 2024 01:11:35 +1100
Subject: [PATCH] chore: enable tie_word_embeddings config for qwen2 (#544)

---
 llms/mlx_lm/models/qwen2.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index f0c19171..7cb5e106 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -21,6 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -176,7 +177,8 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -184,7 +186,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        if hasattr(self, "lm_head"):
+            return self.lm_head(out), cache
+
+        out = out @ self.model.embed_tokens.weight.T
+        return out, cache
 
     @staticmethod
     def sanitize(weights):
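
Note: the change above skips allocating a separate `lm_head` when `tie_word_embeddings`
is set, and instead projects hidden states back to the vocabulary with the transpose of
the input embedding matrix. Below is a minimal standalone sketch of that weight-tying
pattern in MLX; the class name `TiedLM` and the toy sizes are illustrative only, not
part of the patch.

    # Minimal weight-tying sketch (assumes `pip install mlx`); `TiedLM` is a
    # hypothetical toy model, not the Qwen2 implementation from this patch.
    import mlx.core as mx
    import mlx.nn as nn


    class TiedLM(nn.Module):
        def __init__(self, vocab_size: int, dims: int):
            super().__init__()
            # Single parameter matrix shared between input and output.
            self.embed_tokens = nn.Embedding(vocab_size, dims)

        def __call__(self, tokens: mx.array) -> mx.array:
            h = self.embed_tokens(tokens)  # (batch, seq, dims)
            # Reuse the embedding weights as the output projection,
            # mirroring `out @ self.model.embed_tokens.weight.T` above.
            return h @ self.embed_tokens.weight.T  # (batch, seq, vocab_size)


    model = TiedLM(vocab_size=32, dims=8)
    logits = model(mx.array([[1, 2, 3]]))
    print(logits.shape)  # (1, 3, 32)

Because the tied model never creates `self.lm_head`, the `hasattr(self, "lm_head")`
check in `__call__` is what selects between the two projection paths at runtime.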