mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
feat: add mistral tps (#173)
* feat: add mistral tps
* eval params before timing + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
188a91074b
commit
7ae445f6c7
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
@@ -204,6 +205,7 @@ def load_model(folder: str):
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
model.update(weights)
|
||||
mx.eval(model.parameters())
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@@ -265,12 +267,17 @@ if __name__ == "__main__":
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
|
||||
print("[INFO] Starting generation...")
|
||||
|
||||
tic = time.time()
|
||||
print(args.prompt, end="", flush=True)
|
||||
prompt = mx.array(tokenizer.encode(args.prompt))
|
||||
tokens = []
|
||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
tokens.append(token)
|
||||
if ntoks == 0:
|
||||
toc = time.time()
|
||||
mx.eval(tokens)
|
||||
prompt_tps = prompt.size / (toc - tic)
|
||||
tic = time.time()
|
||||
|
||||
if (len(tokens) % args.tokens_per_eval) == 0:
|
||||
mx.eval(tokens)
|
||||
@@ -282,3 +289,8 @@ if __name__ == "__main__":
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, flush=True)
|
||||
print("------")
|
||||
generation_tps = ntoks / (time.time() - tic)
|
||||
print(
|
||||
f"Tokens per second: prompt {prompt_tps:.3f}, "
|
||||
f"generation {generation_tps:.3f}"
|
||||
)
|
||||
|
Reference in New Issue
Block a user