mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 06:00:19 +08:00
Give all inputs when it's the first time call of model
This commit is contained in:
parent
103c6616c4
commit
28f3f3adab
@ -1606,6 +1606,8 @@ class Model(PlamoPreTrainedModel):
|
||||
|
||||
if not config.tie_word_embeddings:
|
||||
self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
|
||||
|
||||
self._prefill = True
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
# self.post_init()
|
||||
@ -1643,6 +1645,9 @@ class Model(PlamoPreTrainedModel):
|
||||
past_key_values=cache,
|
||||
use_cache=self.config.use_cache,
|
||||
)
|
||||
if self._prefill:
|
||||
model_inputs["input_ids"] = inputs
|
||||
self._prefill = False
|
||||
output = self.forward(**model_inputs)
|
||||
if not isinstance(output, CausalLMOutputWithPast):
|
||||
raise ValueError(
|
||||
|
Loading…
Reference in New Issue
Block a user