From 28f3f3adab2a8af1886014ac467e67891bbfcb6b Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Fri, 14 Feb 2025 23:24:45 +0900 Subject: [PATCH] Give all inputs when it's the first time call of model --- llms/mlx_lm/models/plamo2.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index f0f5d25b..0fbb4cee 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -1606,6 +1606,8 @@ class Model(PlamoPreTrainedModel): if not config.tie_word_embeddings: self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False) + + self._prefill = True # Initialize weights and apply final processing # self.post_init() @@ -1643,6 +1645,9 @@ class Model(PlamoPreTrainedModel): past_key_values=cache, use_cache=self.config.use_cache, ) + if self._prefill: + model_inputs["input_ids"] = inputs + self._prefill = False output = self.forward(**model_inputs) if not isinstance(output, CausalLMOutputWithPast): raise ValueError(