fix eos handling in stream generate

This commit is contained in:
Awni Hannun 2024-11-06 06:48:05 -08:00
parent c9994f80e6
commit 6b209c6d3e
3 changed files with 18 additions and 13 deletions

View File

@@ -74,7 +74,7 @@ def main():
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
for response in stream_generate( for response, *_ in stream_generate(
model, model,
tokenizer, tokenizer,
prompt, prompt,

View File

@@ -64,7 +64,7 @@ def stopping_criteria(
end if it has (`trim_length`). end if it has (`trim_length`).
""" """
if tokens and tokens[-1] == eos_token_id: if tokens and tokens[-1] == eos_token_id:
return StopCondition(stop_met=True, trim_length=1) return StopCondition(stop_met=True, trim_length=0)
for stop_ids in stop_id_sequences: for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids): if len(tokens) >= len(stop_ids):
@@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
self.max_tokens = self.body.get("max_completion_tokens", None) self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None: if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512) self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 1.0) self.temperature = self.body.get("temperature", 0.0)
self.top_p = self.body.get("top_p", 1.0) self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20) self.repetition_context_size = self.body.get("repetition_context_size", 20)

View File

@@ -204,8 +204,7 @@ def generate_step(
when ``kv_bits`` is non-None. Default: ``0``. when ``kv_bits`` is non-None. Default: ``0``.
Yields: Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
one token and a vector of log probabilities.
""" """
y = prompt y = prompt
@@ -272,19 +271,21 @@ def stream_generate(
prompt: Union[str, List[int]], prompt: Union[str, List[int]],
max_tokens: int = 100, max_tokens: int = 100,
**kwargs, **kwargs,
) -> Union[str, Generator[str, None, None]]: ) -> Generator[Tuple[str, int, mx.array], None, None]:
""" """
A generator producing text based on the given prompt from the model. A generator producing text based on the given prompt from the model.
Args: Args:
prompt (Union[str, List[int]]): The input prompt.
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
max_tokens (int): The ma tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`. kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details. See :func:`generate_step` for more details.
Yields: Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing text. Tuple[str, int, mx.array]:
The next text segment, token, and vector of log probabilities.
""" """
if not isinstance(tokenizer, TokenizerWrapper): if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer) tokenizer = TokenizerWrapper(tokenizer)
@@ -300,10 +301,14 @@ def stream_generate(
range(max_tokens), range(max_tokens),
generate_step(prompt_tokens, model, **kwargs), generate_step(prompt_tokens, model, **kwargs),
): ):
detokenizer.add_token(token) if token == tokenizer.eos_token_id:
if n == (max_tokens - 1) or token == tokenizer.eos_token_id:
break break
# Yield the last segment if streaming
detokenizer.add_token(token)
if n == (max_tokens - 1):
break
yield detokenizer.last_segment, token, logits yield detokenizer.last_segment, token, logits
detokenizer.finalize() detokenizer.finalize()
@@ -318,7 +323,7 @@ def generate(
verbose: bool = False, verbose: bool = False,
formatter: Optional[Callable] = None, formatter: Optional[Callable] = None,
**kwargs, **kwargs,
) -> Union[str, Generator[str, None, None]]: ) -> str:
""" """
Generate a complete response from the model. Generate a complete response from the model.