feat: add mistral tps (#173)

* feat: add mistral tps

* eval params before timing + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Todsaporn Banjerdkit
2023-12-22 22:55:57 +07:00
committed by GitHub
parent 188a91074b
commit 7ae445f6c7
4 changed files with 22 additions and 9 deletions

View File

@@ -2,6 +2,7 @@
import argparse
import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
@@ -204,6 +205,7 @@ def load_model(folder: str):
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(weights)
mx.eval(model.parameters())
return model, tokenizer
@@ -265,12 +267,17 @@ if __name__ == "__main__":
model, tokenizer = load_model(args.model_path)
print("[INFO] Starting generation...")
tic = time.time()
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
tokens = []
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token)
if ntoks == 0:
toc = time.time()
mx.eval(tokens)
prompt_tps = prompt.size / (toc - tic)
tic = time.time()
if (len(tokens) % args.tokens_per_eval) == 0:
mx.eval(tokens)
@@ -282,3 +289,8 @@ if __name__ == "__main__":
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)
print("------")
generation_tps = ntoks / (time.time() - tic)
print(
f"Tokens per second: prompt {prompt_tps:.3f}, "
f"generation {generation_tps:.3f}"
)