chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)

* chore: fix tie_word_embeddings for qwen2

* chore: default tie_word_embeddings to True
This commit is contained in:
Anchen
2024-03-13 15:34:32 +11:00
committed by GitHub
parent 39084e81c2
commit 3535408c99
5 changed files with 101 additions and 22 deletions

View File

@@ -179,8 +179,7 @@ class Model(nn.Module):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
-@staticmethod
-def sanitize(weights):
+def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k