mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 10:56:38 +08:00
Measure tokens/s
This commit is contained in:
parent
90d3a15ba2
commit
29bfb93455
9
t5/t5.py
9
t5/t5.py
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from time import perf_counter_ns
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -398,6 +399,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
|
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
|
||||||
|
|
||||||
|
start = perf_counter_ns()
|
||||||
|
n_tokens = 0
|
||||||
tokens = []
|
tokens = []
|
||||||
for token, _ in zip(
|
for token, _ in zip(
|
||||||
generate(prompt, decoder_inputs, model, args.temp),
|
generate(prompt, decoder_inputs, model, args.temp),
|
||||||
@ -417,10 +420,16 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
print(s, end="", flush=True)
|
print(s, end="", flush=True)
|
||||||
|
n_tokens += len(tokens)
|
||||||
tokens = []
|
tokens = []
|
||||||
if eos_index is not None:
|
if eos_index is not None:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
end = perf_counter_ns()
|
||||||
|
|
||||||
mx.eval(tokens)
|
mx.eval(tokens)
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
|
n_tokens += len(tokens)
|
||||||
|
elapsed = (end - start) / 1.0e9
|
||||||
print(s, flush=True)
|
print(s, flush=True)
|
||||||
|
print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")
|
Loading…
Reference in New Issue
Block a user