Fix EOS handling in stream_generate

This commit is contained in:
Awni Hannun 2024-11-06 06:48:05 -08:00
parent c9994f80e6
commit 6b209c6d3e
3 changed files with 18 additions and 13 deletions

View File

@@ -74,7 +74,7 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response in stream_generate(
for response, *_ in stream_generate(
model,
tokenizer,
prompt,

View File

@@ -64,7 +64,7 @@ def stopping_criteria(
end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
return StopCondition(stop_met=True, trim_length=1)
return StopCondition(stop_met=True, trim_length=0)
for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids):
@@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 1.0)
self.temperature = self.body.get("temperature", 0.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20)

View File

@@ -204,8 +204,7 @@ def generate_step(
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
y = prompt
@@ -272,19 +271,21 @@ def stream_generate(
prompt: Union[str, List[int]],
max_tokens: int = 100,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> Generator[Tuple[str, int, mx.array], None, None]:
"""
A generator producing text based on the given prompt from the model.
Args:
prompt (Union[str, List[int]]): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The ma
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
Tuple[str, int, mx.array]:
The next text segment, token, and vector of log probabilities.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
@@ -300,10 +301,14 @@ def stream_generate(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
detokenizer.add_token(token)
if n == (max_tokens - 1) or token == tokenizer.eos_token_id:
if token == tokenizer.eos_token_id:
break
# Yield the last segment if streaming
detokenizer.add_token(token)
if n == (max_tokens - 1):
break
yield detokenizer.last_segment, token, logits
detokenizer.finalize()
@@ -318,7 +323,7 @@ def generate(
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> str:
"""
Generate a complete response from the model.