diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index f0c19171..7cb5e106 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -21,6 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -176,7 +177,8 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -184,7 +186,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        if hasattr(self, "lm_head"):
+            return self.lm_head(out), cache
+
+        out = out @ self.model.embed_tokens.weight.T
+        return out, cache
 
     @staticmethod
     def sanitize(weights):
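
For reference, a minimal sketch (not part of the patch) of what the tied-embeddings path computes: when tie_word_embeddings is set, the checkpoint ships no separate lm_head weight, so logits are produced by projecting the decoder output against the transposed embedding matrix. The array shapes below are hypothetical, chosen only for illustration.

    import mlx.core as mx

    # Hypothetical sizes, only for illustration.
    vocab_size, hidden_size = 8, 4

    # Stands in for model.embed_tokens.weight, shape (vocab_size, hidden_size).
    embed_weight = mx.random.normal((vocab_size, hidden_size))

    # Stands in for the decoder output `out`, shape (batch, seq_len, hidden_size).
    out = mx.random.normal((1, 3, hidden_size))

    # Tied projection: reuse the embedding matrix as the output head,
    # exactly as in `out @ self.model.embed_tokens.weight.T` above.
    logits = out @ embed_weight.T
    print(logits.shape)  # (batch=1, seq_len=3, vocab_size=8)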