mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
fix eos handling in stream generate
This commit is contained in:
parent
c9994f80e6
commit
6b209c6d3e
@ -74,7 +74,7 @@ def main():
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
for response in stream_generate(
|
||||
for response, *_ in stream_generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
|
@ -64,7 +64,7 @@ def stopping_criteria(
|
||||
end if it has (`trim_length`).
|
||||
"""
|
||||
if tokens and tokens[-1] == eos_token_id:
|
||||
return StopCondition(stop_met=True, trim_length=1)
|
||||
return StopCondition(stop_met=True, trim_length=0)
|
||||
|
||||
for stop_ids in stop_id_sequences:
|
||||
if len(tokens) >= len(stop_ids):
|
||||
@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.max_tokens = self.body.get("max_completion_tokens", None)
|
||||
if self.max_tokens is None:
|
||||
self.max_tokens = self.body.get("max_tokens", 512)
|
||||
self.temperature = self.body.get("temperature", 1.0)
|
||||
self.temperature = self.body.get("temperature", 0.0)
|
||||
self.top_p = self.body.get("top_p", 1.0)
|
||||
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
||||
|
@ -204,8 +204,7 @@ def generate_step(
|
||||
when ``kv_bits`` is non-None. Default: ``0``.
|
||||
|
||||
Yields:
|
||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||
one token and a vector of log probabilities.
|
||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||
"""
|
||||
|
||||
y = prompt
|
||||
@ -272,19 +271,21 @@ def stream_generate(
|
||||
prompt: Union[str, List[int]],
|
||||
max_tokens: int = 100,
|
||||
**kwargs,
|
||||
) -> Union[str, Generator[str, None, None]]:
|
||||
) -> Generator[Tuple[str, int, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing text based on the given prompt from the model.
|
||||
|
||||
Args:
|
||||
prompt (Union[str, List[int]]): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
max_tokens (int): The ma
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
|
||||
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||
See :func:`generate_step` for more details.
|
||||
|
||||
Yields:
|
||||
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
|
||||
Tuple[str, int, mx.array]:
|
||||
The next text segment, token, and vector of log probabilities.
|
||||
"""
|
||||
if not isinstance(tokenizer, TokenizerWrapper):
|
||||
tokenizer = TokenizerWrapper(tokenizer)
|
||||
@ -300,10 +301,14 @@ def stream_generate(
|
||||
range(max_tokens),
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
):
|
||||
detokenizer.add_token(token)
|
||||
if n == (max_tokens - 1) or token == tokenizer.eos_token_id:
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
# Yield the last segment if streaming
|
||||
|
||||
detokenizer.add_token(token)
|
||||
|
||||
if n == (max_tokens - 1):
|
||||
break
|
||||
|
||||
yield detokenizer.last_segment, token, logits
|
||||
|
||||
detokenizer.finalize()
|
||||
@ -318,7 +323,7 @@ def generate(
|
||||
verbose: bool = False,
|
||||
formatter: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
) -> Union[str, Generator[str, None, None]]:
|
||||
) -> str:
|
||||
"""
|
||||
Generate a complete response from the model.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user