From f8fadf7a178013e927f625481617f9d813ed6dac Mon Sep 17 00:00:00 2001
From: David Koski <46639364+davidkoski@users.noreply.github.com>
Date: Tue, 30 Jan 2024 11:24:16 -0800
Subject: [PATCH] Fix token count computation to fix tps measurements (#392)

---
 llms/mlx_lm/utils.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 98ec7980..726342e0 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -171,21 +171,22 @@ def generate(
             print(s[skip:], end="", flush=True)
             skip = len(s)
 
-    tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
+    token_count = len(tokens)
+    token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
 
     if verbose:
-        print(tokens[skip:], flush=True)
+        print(token_string[skip:], flush=True)
         gen_time = time.perf_counter() - tic
         print("=" * 10)
-        if len(tokens) == 0:
+        if token_count == 0:
             print("No tokens generated for this prompt")
             return
         prompt_tps = prompt.size / prompt_time
-        gen_tps = (len(tokens) - 1) / gen_time
+        gen_tps = (token_count - 1) / gen_time
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
-    return tokens
+    return token_string
 
 
 def load_model(model_path: Path) -> nn.Module:
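
Context (not part of the patch): before this change, the decoded text was rebound to the name `tokens`, so `len(tokens)` later measured characters of the decoded string rather than generated token ids, inflating the reported generation tokens-per-sec. The sketch below illustrates the difference; `FakeTokenizer` and the constants are stand-ins for illustration only, not the real mlx_lm tokenizer or timings.

# Minimal sketch (assumption: FakeTokenizer stands in for the real tokenizer;
# only decode() is modeled). Rebinding `tokens` to the decoded string makes
# len(tokens) a character count, which skews the tokens-per-sec figure.

class FakeTokenizer:
    def decode(self, token_ids):
        # Each token id decodes to a multi-character word, as real tokens do.
        return " ".join(f"tok{i}" for i in token_ids)

tokenizer = FakeTokenizer()
tokens = list(range(50))   # 50 generated token ids
gen_time = 2.0             # seconds; stands in for time.perf_counter() - tic

# Old behavior: `tokens` rebound to the decoded string before counting.
decoded = tokenizer.decode(tokens)
inflated_tps = (len(decoded) - 1) / gen_time   # counts characters, not tokens

# Patched behavior: capture the token count before decoding.
token_count = len(tokens)
correct_tps = (token_count - 1) / gen_time

print(f"inflated tps (characters): {inflated_tps:.3f}")
print(f"correct tps (tokens):      {correct_tps:.3f}")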