diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 85d32d5f..c03056a6 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -74,7 +74,7 @@ def main(): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - for response in stream_generate( + for response, *_ in stream_generate( model, tokenizer, prompt, diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 9c949291..c1365b36 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -64,7 +64,7 @@ def stopping_criteria( end if it has (`trim_length`). """ if tokens and tokens[-1] == eos_token_id: - return StopCondition(stop_met=True, trim_length=1) + return StopCondition(stop_met=True, trim_length=0) for stop_ids in stop_id_sequences: if len(tokens) >= len(stop_ids): @@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler): self.max_tokens = self.body.get("max_completion_tokens", None) if self.max_tokens is None: self.max_tokens = self.body.get("max_tokens", 512) - self.temperature = self.body.get("temperature", 1.0) + self.temperature = self.body.get("temperature", 0.0) self.top_p = self.body.get("top_p", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0) self.repetition_context_size = self.body.get("repetition_context_size", 20) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 240e5dd9..8893b570 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -204,8 +204,7 @@ def generate_step( when ``kv_bits`` is non-None. Default: ``0``. Yields: - Generator[Tuple[mx.array, mx.array], None, None]: A generator producing - one token and a vector of log probabilities. + Tuple[mx.array, mx.array]: One token and a vector of log probabilities. """ y = prompt @@ -272,19 +271,21 @@ def stream_generate( prompt: Union[str, List[int]], max_tokens: int = 100, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> Generator[Tuple[str, int, mx.array], None, None]: """ A generator producing text based on the given prompt from the model. Args: - prompt (Union[str, List[int]]): The input prompt. model (nn.Module): The model to use for generation. - max_tokens (int): The ma + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (Union[str, List[int]]): The input prompt string or integer tokens. + max_tokens (int): The maximum number of tokens. Default: ``100``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. Yields: - Generator[Tuple[mx.array, mx.array]]: A generator producing text. + Tuple[str, int, mx.array]: + The next text segment, token, and vector of log probabilities. """ if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) @@ -300,10 +301,14 @@ def stream_generate( range(max_tokens), generate_step(prompt_tokens, model, **kwargs), ): - detokenizer.add_token(token) - if n == (max_tokens - 1) or token == tokenizer.eos_token_id: + if token == tokenizer.eos_token_id: break - # Yield the last segment if streaming + + detokenizer.add_token(token) + + if n == (max_tokens - 1): + break + yield detokenizer.last_segment, token, logits detokenizer.finalize() @@ -318,7 +323,7 @@ def generate( verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> str: """ Generate a complete response from the model.