mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00

Fix token count computation to fix tps measurements (#392)

commit f8fadf7a17 (parent 614de6652f)
@@ -171,21 +171,22 @@ def generate(
             print(s[skip:], end="", flush=True)
             skip = len(s)
 
-    tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
+    token_count = len(tokens)
+    token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
 
     if verbose:
-        print(tokens[skip:], flush=True)
+        print(token_string[skip:], flush=True)
         gen_time = time.perf_counter() - tic
         print("=" * 10)
-        if len(tokens) == 0:
+        if token_count == 0:
             print("No tokens generated for this prompt")
             return
         prompt_tps = prompt.size / prompt_time
-        gen_tps = (len(tokens) - 1) / gen_time
+        gen_tps = (token_count - 1) / gen_time
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
-    return tokens
+    return token_string
 
 
 def load_model(model_path: Path) -> nn.Module:
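The bug the diff fixes: `generate` rebound `tokens` (a list of generated token ids) to the decoded string before the tokens-per-second computation ran, so `len(tokens)` counted characters rather than tokens and inflated the reported tps. The commit counts tokens first (`token_count`), then decodes into a separately named `token_string`. A minimal self-contained sketch of the before/after arithmetic; the token ids, timing value, and `fake_decode` stand-in for `tokenizer.decode()` are made up for illustration:

# Illustrative sketch only; fake_decode stands in for tokenizer.decode().
def fake_decode(ids):
    # Pretend each token id decodes to a short word.
    return " ".join(f"tok{i}" for i in ids)

tokens = [17, 42, 256, 1024]   # 4 generated token ids
gen_time = 2.0                 # seconds (made-up)

# Before the fix: `tokens` is rebound to the decoded *string*,
# so len(tokens) counts characters, not tokens.
tokens = fake_decode(tokens)                # "tok17 tok42 tok256 tok1024"
buggy_tps = (len(tokens) - 1) / gen_time    # (26 - 1) / 2.0 = 12.5

# After the fix: count tokens before decoding.
token_ids = [17, 42, 256, 1024]
token_count = len(token_ids)                # 4 tokens
token_string = fake_decode(token_ids)
fixed_tps = (token_count - 1) / gen_time    # (4 - 1) / 2.0 = 1.5

print(f"buggy: {buggy_tps:.3f} tps, fixed: {fixed_tps:.3f} tps")

Renaming the decoded result to `token_string` also keeps the function's return value a string while leaving the id list available for counting, which is why the return statement changes as well.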