From 5a52899405ee911ce07ca3132e781ce149acd6b5 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sat, 23 Mar 2024 15:32:33 -0700
Subject: [PATCH] Partially stream de-tokenization (#609)

* partially stream de-tokenization

* don't break full response
---
 llms/mlx_lm/utils.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 03e0fbd3..4f5f8b15 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -224,6 +224,7 @@ def generate(
 
     tic = time.perf_counter()
     tokens = []
+    token_strings = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"
 
@@ -250,15 +251,20 @@
             if formatter:
                 formatter(s[skip:], prob.item())
                 skip = len(s)
-            elif REPLACEMENT_CHAR not in s:
+            elif s[-1] != REPLACEMENT_CHAR:
                 print(s[skip:], end="", flush=True)
                 skip = len(s)
+                # Reset token cache at line break
+                if s[-1] == "\n":
+                    tokens = []
+                    token_strings.append(s)
+                    skip = 0
 
-    token_count = len(tokens)
-    token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
+    token_count = n + 1
+    token_strings.append(tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, ""))
 
     if verbose:
-        print(token_string[skip:], flush=True)
+        print(token_strings[-1][skip:], flush=True)
         gen_time = time.perf_counter() - tic
         print("=" * 10)
         if token_count == 0:
@@ -269,7 +275,7 @@
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
-    return token_string
+    return "".join(token_strings)
 
 
 def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
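
The patch streams partially de-tokenized text by re-decoding the token list after every step, holding output back only while the decoded string ends in U+FFFD (a sign that a multi-byte UTF-8 character is still split across token boundaries), and flushing the token cache into `token_strings` at each newline so that re-decoding the list stays cheap while the full response can still be reassembled at the end. Below is a minimal self-contained sketch of the same scheme; the name `stream_detokenize` and the injected `decode` callable are hypothetical stand-ins for mlx_lm's `tokenizer.decode`, not actual mlx_lm API.

    REPLACEMENT_CHAR = "\ufffd"  # U+FFFD marks an incomplete UTF-8 sequence

    def stream_detokenize(token_ids, decode):
        tokens = []    # ids of the segment currently being decoded
        segments = []  # completed segments, flushed at each newline
        skip = 0       # chars of the current segment already printed
        for t in token_ids:
            tokens.append(t)
            s = decode(tokens)
            # Hold output back while the tail is a partial multi-byte
            # character, which decodes to the replacement character.
            if s and s[-1] != REPLACEMENT_CHAR:
                print(s[skip:], end="", flush=True)
                skip = len(s)
                if s[-1] == "\n":
                    # Reset the token cache so each re-decode stays short.
                    segments.append(s)
                    tokens, skip = [], 0
        segments.append(decode(tokens).replace(REPLACEMENT_CHAR, ""))
        return "".join(segments)

    # Toy byte-level decoder: each "token" is one UTF-8 byte, so the
    # two-byte "é" exercises the U+FFFD hold-back path.
    out = stream_detokenize(
        list("héllo\nworld".encode("utf-8")),
        lambda ids: bytes(ids).decode("utf-8", errors="replace"),
    )
    assert out == "héllo\nworld"

Note the switch from `REPLACEMENT_CHAR not in s` to `s[-1] != REPLACEMENT_CHAR`: checking only the last character keeps output flowing even if the generated text legitimately contains U+FFFD somewhere earlier, where the old containment test would have stopped streaming for the rest of the response.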