From 8a178f87163ecaa74aace9c49616576364866087 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 8 Mar 2024 01:11:35 +1100
Subject: [PATCH] chore: enable tie_word_embeddings config for qwen2 (#544)
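
When tie_word_embeddings is set in the model config, the model no longer
allocates a separate lm_head projection; the final hidden states are instead
projected back through the transposed token-embedding matrix. The flag
defaults to False, so existing untied checkpoints load unchanged.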
---
 llms/mlx_lm/models/qwen2.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index f0c19171..7cb5e106 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -21,6 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = False

     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -176,7 +177,8 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

     def __call__(
         self,
@@ -184,7 +186,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        if hasattr(self, "lm_head"):
+            return self.lm_head(out), cache
+
+        out = out @ self.model.embed_tokens.weight.T
+        return out, cache

     @staticmethod
     def sanitize(weights):
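
For context (not part of the patch): a minimal sketch of the weight tying this
change enables. The toy dimensions and variable names below are illustrative
assumptions; only the `hidden @ embed.weight.T` projection mirrors what the
patched __call__ does when no separate lm_head exists.

    import mlx.core as mx
    import mlx.nn as nn

    vocab_size, dims = 32, 8           # toy sizes, assumptions for illustration
    embed = nn.Embedding(vocab_size, dims)

    tokens = mx.array([[1, 5, 7]])     # (batch, seq)
    hidden = embed(tokens)             # (batch, seq, dims), stand-in for the transformer output

    # Untied: a separate output projection with its own weight matrix.
    lm_head = nn.Linear(dims, vocab_size, bias=False)
    logits_untied = lm_head(hidden)    # (batch, seq, vocab_size)

    # Tied: reuse the embedding matrix as the output projection.
    logits_tied = hidden @ embed.weight.T   # (batch, seq, vocab_size)

Tying halves the parameter count spent on the vocabulary mapping, which is why
the patch only creates lm_head when the config leaves tie_word_embeddings off.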