chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)

* chore: fix tie_word_embeddings for qwen2

* chore: default tie_word_embeddings to True
This commit is contained in:
Anchen
2024-03-13 15:34:32 +11:00
committed by GitHub
parent 39084e81c2
commit 3535408c99
5 changed files with 101 additions and 22 deletions

View File

@@ -21,7 +21,7 @@ class ModelArgs(BaseModelArgs):
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
@@ -168,10 +168,10 @@ class Qwen2Model(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Qwen2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
@@ -179,14 +179,11 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
if hasattr(self, "lm_head"):
return self.lm_head(out), cache
return self.lm_head(out), cache
out = out @ self.model.embed_tokens.weight.T
return out, cache
@staticmethod
def sanitize(weights):
def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k