Fix missing variable

This commit is contained in:
Shunta Saito 2025-02-28 01:31:15 +09:00
parent 08a8dd2507
commit ab960f80dd

View File

@ -602,7 +602,7 @@ class Model(nn.Module):
if not config.tie_word_embeddings:
self.lm_head: nn.Module = nn.Linear(
config.hidden_size, vocab_size, bias=False
config.hidden_size, self.vocab_size, bias=False
)
def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]: