Fix token count computation to correct tokens-per-sec measurements (#392)
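
Before this change, tokens was rebound from the list of generated token ids to its decoded string, so every later len(tokens) counted characters rather than tokens and inflated the reported generation tokens-per-sec. The fix takes token_count = len(tokens) while tokens still holds the ids and decodes into a separate token_string, leaving the returned text unchanged.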

David Koski 2024-01-30 11:24:16 -08:00 committed by GitHub
parent 614de6652f
commit f8fadf7a17


@@ -171,21 +171,22 @@ def generate(
             print(s[skip:], end="", flush=True)
             skip = len(s)
 
-    tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
+    token_count = len(tokens)
+    token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
 
     if verbose:
-        print(tokens[skip:], flush=True)
+        print(token_string[skip:], flush=True)
         gen_time = time.perf_counter() - tic
         print("=" * 10)
-        if len(tokens) == 0:
+        if token_count == 0:
             print("No tokens generated for this prompt")
             return
         prompt_tps = prompt.size / prompt_time
-        gen_tps = (len(tokens) - 1) / gen_time
+        gen_tps = (token_count - 1) / gen_time
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
-    return tokens
+    return token_string
 
 
 def load_model(model_path: Path) -> nn.Module:
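
For context, the core mistake was rebinding tokens (a list of ids) to its decoded string and then measuring len() of the result. Below is a minimal, self-contained sketch of the before/after behavior; the toy decode function is a hypothetical stand-in for the real tokenizer, used only to make the character/token mismatch visible:

# Toy stand-in for tokenizer.decode(); hypothetical, not from this repo.
def decode(token_ids):
    # Pretend each token id expands to several characters of text.
    return "".join(f"<tok{i}>" for i in token_ids)

tokens = [17, 42, 7]  # three generated token ids

# Before the fix: tokens is rebound to the decoded string, so len()
# counts characters, not tokens.
tokens = decode(tokens)
print(len(tokens))  # 20 characters -> gen_tps overstated roughly 7x here

# After the fix: count the ids first, then decode into a separate variable.
tokens = [17, 42, 7]
token_count = len(tokens)      # 3, the actual number of generated tokens
token_string = decode(tokens)  # decoded text, returned to the caller
print(token_count)             # 3 -> correct tokens-per-sec numerator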