From 566af15c34d79255fd75639b42827f0753c65226 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 7 Nov 2024 19:43:42 -0800 Subject: [PATCH] fixes --- llms/README.md | 6 +++--- llms/mlx_lm/chat.py | 4 ++-- llms/mlx_lm/server.py | 3 ++- llms/mlx_lm/utils.py | 13 ++++++------- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/llms/README.md b/llms/README.md index eeb3ed6a..4976c39e 100644 --- a/llms/README.md +++ b/llms/README.md @@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -response = generate(model, tokenizer, prompt=prompt, verbose=True) +text = generate(model, tokenizer, prompt=prompt, verbose=True) ``` To see a description of all the arguments you can do: @@ -117,8 +117,8 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512): - print(t, end="", flush=True) +for response in stream_generate(model, tokenizer, prompt, max_tokens=512): + print(response.text, end="", flush=True) print() ``` diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index c03056a6..09a39e59 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -74,7 +74,7 @@ def main(): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - for response, *_ in stream_generate( + for response in stream_generate( model, tokenizer, prompt, @@ -83,7 +83,7 @@ def main(): top_p=args.top_p, prompt_cache=prompt_cache, ): - print(response, flush=True, end="") + print(response.text, flush=True, end="") print() diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index a71c305e..d96a3f72 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -475,7 +475,8 @@ class APIHandler(BaseHTTPRequestHandler): logit_bias=self.logit_bias, prompt_cache=self.prompt_cache.cache, ): - text += gen_response.text + segment = gen_response.text + text += segment logging.debug(text) token = gen_response.token logprobs = gen_response.logprobs diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 33f32cc8..15b0af2d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -378,29 +378,28 @@ def generate( model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. prompt (str): The string prompt. + max_tokens (int): The maximum number of tokens. Default: ``100``. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. kwargs: The remaining options get passed to :func:`stream_generate`. See :func:`stream_generate` for more details. """ if formatter is not None: - print( - "Text formatting has been deprecated and will be removed in the next version." - ) + print("Text formatting is deprecated and will be removed in the next version.") if verbose: print("=" * 10) print("Prompt:", prompt) - full_text = "" + text = "" for response in stream_generate(model, tokenizer, prompt, **kwargs): if verbose: print(response.text, end="", flush=True) - full_text += response.text + text += response.text if verbose: print() print("=" * 10) - if len(full_text) == 0: + if len(text) == 0: print("No text generated for this prompt") return print( @@ -412,7 +411,7 @@ def generate( f"{response.generation_tps:.3f} tokens-per-sec" ) print(f"Peak memory: {response.peak_memory:.3f} GB") - return full_text + return text def load_config(model_path: Path) -> dict: