chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)

* chore: fix tie_word_embeddings for qwen2

* chore: default tie_word_embeddings to True
This commit is contained in:
Anchen
2024-03-13 15:34:32 +11:00
committed by GitHub
parent 39084e81c2
commit 3535408c99
5 changed files with 101 additions and 22 deletions

View File

@@ -147,11 +147,10 @@ class Starcoder2Model(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Starcoder2Model(args)
# For 15B starcoder2 and fine-tuned models which don't tie word embeddings
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
@@ -159,11 +158,12 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
if not self.model.args.tie_word_embeddings:
return self.lm_head(out), cache
else:
out = out @ self.model.embed_tokens.weight.T
return out, cache
return self.lm_head(out), cache
def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
return weights
@property
def layers(self):