update: pre-commit format hook

B1ACK917 2025-03-12 23:10:58 +09:00
parent 346d9641d7
commit f5cd999774
No known key found for this signature in database
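The diff below is consistent with running the usual Python formatting hooks: isort sorts import statements and the names inside each `from ... import`, and black re-wraps statements that exceed its default 88-column line length. As a minimal sketch (not part of this commit, the repository's actual pre-commit configuration is not shown here, and it assumes the black and isort packages are installed), the same kind of rewrite can be reproduced programmatically:

import black
import isort

# Sample lines copied from the diff below; they are only formatted, never executed.
SRC = (
    "from .utils import load, generate\n"
    'logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")\n'
)

sorted_src = isort.code(SRC)  # sorts the imported names: load, generate -> generate, load
formatted = black.format_str(sorted_src, mode=black.Mode())  # wraps the 89-column call at 88 columns
print(formatted)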


@@ -1,3 +1,8 @@
+# @Author : Dawei Feng
+# @Time : 2025/3/12 22:00
+# @File : bench.py
+# @Email : darkv.feng@outlook.com
"""
MLX-LM Benchmark Tool
@@ -14,19 +19,21 @@ It supports multiple input values for model, prompt tokens, and generation tokens
"""
import argparse
+import contextlib
+import csv
+import io
import json
import logging
import random
import re
-import io
-import contextlib
-import csv
import time
from typing import Any, Dict, List, Optional, Union
-from .utils import load, generate
+from .utils import generate, load
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+logging.basicConfig(
+level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
+)
class CommaSeparatedIntegers(argparse.Action):
@@ -88,7 +95,11 @@ def parse_args() -> argparse.Namespace:
help="Outout Sequence Length (OSL). Number of tokens to generate. Accepts multiple comma-separated values.",
)
parser.add_argument(
"-r", "--repetitions", type=int, default=5, help="Number of benchmark repetitions to average results over."
"-r",
"--repetitions",
type=int,
default=5,
help="Number of benchmark repetitions to average results over.",
)
parser.add_argument(
"-o",
@@ -168,7 +179,9 @@ def generate_synthetic_tokens(tokenizer: Any, seq_length: int) -> List[int]:
# Prepend BOS token if available; otherwise, start with an empty list.
tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
-tokens += [random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))]
+tokens += [
+random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))
+]
return tokens
@@ -199,13 +212,17 @@ def parse_metrics(log_output: str) -> Dict[str, Optional[float]]:
}
# Extract prompt tokens and tokens-per-second
prompt_match = re.search(r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
prompt_match = re.search(
r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
)
if prompt_match:
metrics["prompt_tokens"] = int(prompt_match.group(1))
metrics["prompt_tps"] = float(prompt_match.group(2))
# Extract generation tokens and tokens-per-second
generation_match = re.search(r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
generation_match = re.search(
r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
)
if generation_match:
metrics["response_tokens"] = int(generation_match.group(1))
metrics["response_tps"] = float(generation_match.group(2))
@@ -248,12 +265,23 @@ def benchmark_performance(
input_tokens = generate_synthetic_tokens(tokenizer, seq_length)
output_buffer = io.StringIO()
with contextlib.redirect_stdout(output_buffer):
-generate(model, tokenizer, input_tokens, max_tokens=max_tokens, verbose=True, **generate_kwargs)
+generate(
+model,
+tokenizer,
+input_tokens,
+max_tokens=max_tokens,
+verbose=True,
+**generate_kwargs,
+)
captured_output = output_buffer.getvalue()
return parse_metrics(captured_output)
-def save_results(output_file, results: Union[Dict[str, Any], List[Dict[str, Any]]], output_format: str) -> None:
+def save_results(
+output_file,
+results: Union[Dict[str, Any], List[Dict[str, Any]]],
+output_format: str,
+) -> None:
"""
Save the benchmark results in the specified output format.
@@ -346,18 +374,31 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
for n_prompt in args.n_prompt:
for n_gen in args.n_gen:
# Warmup run
-_ = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+_ = benchmark_performance(
+model, tokenizer, n_prompt, n_gen, **args.gen_args
+)
# Benchmark iterations
metrics_list = []
for i in range(args.repetitions):
-metrics = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+metrics = benchmark_performance(
+model, tokenizer, n_prompt, n_gen, **args.gen_args
+)
metrics_list.append(metrics)
# Compute average metrics
avg_metrics = {}
keys = ["prompt_tokens", "prompt_tps", "response_tokens", "response_tps", "exec_time", "ram_usage"]
keys = [
"prompt_tokens",
"prompt_tps",
"response_tokens",
"response_tps",
"exec_time",
"ram_usage",
]
for key in keys:
valid_values = [m[key] for m in metrics_list if m[key] is not None]
-avg_metrics[key] = sum(valid_values) / len(valid_values) if valid_values else None
+avg_metrics[key] = (
+sum(valid_values) / len(valid_values) if valid_values else None
+)
result = {
"Model": model_path,
"Model Load Time (s)": round(model_load_time, 3),
@@ -367,7 +408,9 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
"Response TPS": round(avg_metrics["response_tps"], 3),
"Execution Time (s)": round(avg_metrics["exec_time"], 3),
"Memory Usage (GB)": (
round(avg_metrics["ram_usage"], 2) if avg_metrics["ram_usage"] is not None else None
round(avg_metrics["ram_usage"], 2)
if avg_metrics["ram_usage"] is not None
else None
),
}
# Print the result row immediately after each test completes