remove lm head if unused

This commit is contained in:
Awni Hannun 2025-03-06 06:18:46 -08:00
parent 32d10036de
commit 717e415ad4

View File

@ -196,9 +196,12 @@ class Model(nn.Module):
def sanitize(self, weights): def sanitize(self, weights):
# Remove unused precomputed rotary freqs # Remove unused precomputed rotary freqs
return { weights = {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
} }
if self.args.tie_word_embeddings:
weights.pop("lm_head.weight", None)
return weights
@property @property
def layers(self): def layers(self):