diff --git a/t5/t5.py b/t5/t5.py index 571abbfb..b44f054a 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -1,6 +1,7 @@ import argparse from dataclasses import dataclass from typing import Optional +from time import perf_counter_ns import numpy as np import mlx.core as mx @@ -398,6 +399,8 @@ if __name__ == "__main__": decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32) + start = perf_counter_ns() + n_tokens = 0 tokens = [] for token, _ in zip( generate(prompt, decoder_inputs, model, args.temp), @@ -417,10 +420,16 @@ if __name__ == "__main__": s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) + n_tokens += len(tokens) tokens = [] if eos_index is not None: break + end = perf_counter_ns() + mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens]) + n_tokens += len(tokens) + elapsed = (end - start) / 1.0e9 print(s, flush=True) + print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}") \ No newline at end of file