mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:54:39 +08:00
Give all inputs on the first call of the model
This commit is contained in:
parent
103c6616c4
commit
28f3f3adab
@@ -1607,6 +1607,8 @@ class Model(PlamoPreTrainedModel):
|
|||||||
if not config.tie_word_embeddings:
|
if not config.tie_word_embeddings:
|
||||||
self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
|
self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
|
||||||
|
|
||||||
|
self._prefill = True
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
# self.post_init()
|
# self.post_init()
|
||||||
|
|
||||||
@@ -1645,6 +1645,9 @@ class Model(PlamoPreTrainedModel):
|
|||||||
past_key_values=cache,
|
past_key_values=cache,
|
||||||
use_cache=self.config.use_cache,
|
use_cache=self.config.use_cache,
|
||||||
)
|
)
|
||||||
|
if self._prefill:
|
||||||
|
model_inputs["input_ids"] = inputs
|
||||||
|
self._prefill = False
|
||||||
output = self.forward(**model_inputs)
|
output = self.forward(**model_inputs)
|
||||||
if not isinstance(output, CausalLMOutputWithPast):
|
if not isinstance(output, CausalLMOutputWithPast):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
Loading…
Reference in New Issue
Block a user