mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
fix eos handling in stream generate
This commit is contained in:
parent
c9994f80e6
commit
6b209c6d3e
@ -74,7 +74,7 @@ def main():
|
|||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
for response in stream_generate(
|
for response, *_ in stream_generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
|
@ -64,7 +64,7 @@ def stopping_criteria(
|
|||||||
end if it has (`trim_length`).
|
end if it has (`trim_length`).
|
||||||
"""
|
"""
|
||||||
if tokens and tokens[-1] == eos_token_id:
|
if tokens and tokens[-1] == eos_token_id:
|
||||||
return StopCondition(stop_met=True, trim_length=1)
|
return StopCondition(stop_met=True, trim_length=0)
|
||||||
|
|
||||||
for stop_ids in stop_id_sequences:
|
for stop_ids in stop_id_sequences:
|
||||||
if len(tokens) >= len(stop_ids):
|
if len(tokens) >= len(stop_ids):
|
||||||
@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
self.max_tokens = self.body.get("max_completion_tokens", None)
|
self.max_tokens = self.body.get("max_completion_tokens", None)
|
||||||
if self.max_tokens is None:
|
if self.max_tokens is None:
|
||||||
self.max_tokens = self.body.get("max_tokens", 512)
|
self.max_tokens = self.body.get("max_tokens", 512)
|
||||||
self.temperature = self.body.get("temperature", 1.0)
|
self.temperature = self.body.get("temperature", 0.0)
|
||||||
self.top_p = self.body.get("top_p", 1.0)
|
self.top_p = self.body.get("top_p", 1.0)
|
||||||
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||||
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
||||||
|
@ -204,8 +204,7 @@ def generate_step(
|
|||||||
when ``kv_bits`` is non-None. Default: ``0``.
|
when ``kv_bits`` is non-None. Default: ``0``.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||||
one token and a vector of log probabilities.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
y = prompt
|
y = prompt
|
||||||
@ -272,19 +271,21 @@ def stream_generate(
|
|||||||
prompt: Union[str, List[int]],
|
prompt: Union[str, List[int]],
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 100,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, Generator[str, None, None]]:
|
) -> Generator[Tuple[str, int, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing text based on the given prompt from the model.
|
A generator producing text based on the given prompt from the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (Union[str, List[int]]): The input prompt.
|
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
max_tokens (int): The ma
|
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||||
|
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
|
||||||
|
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||||
See :func:`generate_step` for more details.
|
See :func:`generate_step` for more details.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
|
Tuple[str, int, mx.array]:
|
||||||
|
The next text segment, token, and vector of log probabilities.
|
||||||
"""
|
"""
|
||||||
if not isinstance(tokenizer, TokenizerWrapper):
|
if not isinstance(tokenizer, TokenizerWrapper):
|
||||||
tokenizer = TokenizerWrapper(tokenizer)
|
tokenizer = TokenizerWrapper(tokenizer)
|
||||||
@ -300,10 +301,14 @@ def stream_generate(
|
|||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
generate_step(prompt_tokens, model, **kwargs),
|
||||||
):
|
):
|
||||||
detokenizer.add_token(token)
|
if token == tokenizer.eos_token_id:
|
||||||
if n == (max_tokens - 1) or token == tokenizer.eos_token_id:
|
|
||||||
break
|
break
|
||||||
# Yield the last segment if streaming
|
|
||||||
|
detokenizer.add_token(token)
|
||||||
|
|
||||||
|
if n == (max_tokens - 1):
|
||||||
|
break
|
||||||
|
|
||||||
yield detokenizer.last_segment, token, logits
|
yield detokenizer.last_segment, token, logits
|
||||||
|
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
@ -318,7 +323,7 @@ def generate(
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
formatter: Optional[Callable] = None,
|
formatter: Optional[Callable] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, Generator[str, None, None]]:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a complete response from the model.
|
Generate a complete response from the model.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user