diff --git a/llms/hf_llm/generate.py b/llms/hf_llm/generate.py index d0b41fe0..0b8f7ea2 100644 --- a/llms/hf_llm/generate.py +++ b/llms/hf_llm/generate.py @@ -46,6 +46,9 @@ def generate( print(tokenizer.decode(tokens)[skip:], flush=True) gen_time = time.time() - tic print("=" * 10) + if len(tokens) == 0: + print("No tokens generated for this prompt") + return prompt_tps = prompt.size / prompt_time gen_tps = (len(tokens) - 1) / gen_time print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")