diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index f0f5d25b..0fbb4cee 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -1606,6 +1606,8 @@ class Model(PlamoPreTrainedModel): if not config.tie_word_embeddings: self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False) + + self._prefill = True # Initialize weights and apply final processing # self.post_init() @@ -1643,6 +1645,9 @@ class Model(PlamoPreTrainedModel): past_key_values=cache, use_cache=self.config.use_cache, ) + if self._prefill: + model_inputs["input_ids"] = inputs + self._prefill = False output = self.forward(**model_inputs) if not isinstance(output, CausalLMOutputWithPast): raise ValueError(