Measure tokens/s

This commit is contained in:
Juarez Bochi 2023-12-17 10:53:49 -05:00
parent 90d3a15ba2
commit 29bfb93455
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@@ -1,6 +1,7 @@
import argparse
from dataclasses import dataclass
from typing import Optional
from time import perf_counter_ns
import numpy as np
import mlx.core as mx
@@ -398,6 +399,8 @@ if __name__ == "__main__":
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
start = perf_counter_ns()
n_tokens = 0
tokens = []
for token, _ in zip(
generate(prompt, decoder_inputs, model, args.temp),
@@ -417,10 +420,16 @@ if __name__ == "__main__":
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
n_tokens += len(tokens)
tokens = []
if eos_index is not None:
break
end = perf_counter_ns()
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
n_tokens += len(tokens)
elapsed = (end - start) / 1.0e9
print(s, flush=True)
print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")