diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 7b452ea4..117adf0f 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -196,9 +196,12 @@ class Model(nn.Module):
 
     def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
-        return {
+        weights = {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
+        if self.args.tie_word_embeddings:
+            weights.pop("lm_head.weight", None)
+        return weights
 
     @property
     def layers(self):