diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index e9410e0f..14983350 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -602,7 +602,7 @@ class Model(nn.Module): if not config.tie_word_embeddings: self.lm_head: nn.Module = nn.Linear( - config.hidden_size, vocab_size, bias=False + config.hidden_size, self.vocab_size, bias=False ) def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]: