update: pre-commit format hook

B1ACK917 2025-03-12 23:10:58 +09:00
parent 346d9641d7
commit f5cd999774

bench.py

@@ -1,3 +1,8 @@
+# @Author : Dawei Feng
+# @Time : 2025/3/12 22:00
+# @File : bench.py
+# @Email : darkv.feng@outlook.com
 """
 MLX-LM Benchmark Tool
@@ -14,19 +19,21 @@ It supports multiple input values for model, prompt tokens, and generation tokens
 """
 import argparse
+import contextlib
+import csv
+import io
 import json
 import logging
 import random
 import re
-import io
-import contextlib
-import csv
 import time
 from typing import Any, Dict, List, Optional, Union

-from .utils import load, generate
+from .utils import generate, load

-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
+)


 class CommaSeparatedIntegers(argparse.Action):
@@ -88,7 +95,11 @@ def parse_args() -> argparse.Namespace:
         help="Outout Sequence Length (OSL). Number of tokens to generate. Accepts multiple comma-separated values.",
     )
     parser.add_argument(
-        "-r", "--repetitions", type=int, default=5, help="Number of benchmark repetitions to average results over."
+        "-r",
+        "--repetitions",
+        type=int,
+        default=5,
+        help="Number of benchmark repetitions to average results over.",
     )
     parser.add_argument(
         "-o",
@@ -168,7 +179,9 @@ def generate_synthetic_tokens(tokenizer: Any, seq_length: int) -> List[int]:
     # Prepend BOS token if available; otherwise, start with an empty list.
     tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
-    tokens += [random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))]
+    tokens += [
+        random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))
+    ]
     return tokens
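
generate_synthetic_tokens builds a prompt of exactly seq_length token ids: the BOS token first when the tokenizer has one, then uniformly random ids. A self-contained sketch of the same idea with a stand-in tokenizer object; the real function looks up vocab_size earlier, outside this hunk, so that lookup is an assumption here.

    import random
    from types import SimpleNamespace

    # Stand-in tokenizer exposing only the two fields this sketch needs.
    tokenizer = SimpleNamespace(bos_token_id=1, vocab_size=32000)

    def synthetic_tokens(tokenizer, seq_length):
        vocab_size = tokenizer.vocab_size  # assumed source of vocab_size
        tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
        tokens += [
            random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))
        ]
        return tokens

    print(len(synthetic_tokens(tokenizer, 128)))  # 128
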
@@ -199,13 +212,17 @@ def parse_metrics(log_output: str) -> Dict[str, Optional[float]]:
     }
     # Extract prompt tokens and tokens-per-second
-    prompt_match = re.search(r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
+    prompt_match = re.search(
+        r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
+    )
     if prompt_match:
         metrics["prompt_tokens"] = int(prompt_match.group(1))
         metrics["prompt_tps"] = float(prompt_match.group(2))
     # Extract generation tokens and tokens-per-second
-    generation_match = re.search(r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
+    generation_match = re.search(
+        r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
+    )
     if generation_match:
         metrics["response_tokens"] = int(generation_match.group(1))
         metrics["response_tps"] = float(generation_match.group(2))
@@ -248,12 +265,23 @@ def benchmark_performance(
     input_tokens = generate_synthetic_tokens(tokenizer, seq_length)
     output_buffer = io.StringIO()
     with contextlib.redirect_stdout(output_buffer):
-        generate(model, tokenizer, input_tokens, max_tokens=max_tokens, verbose=True, **generate_kwargs)
+        generate(
+            model,
+            tokenizer,
+            input_tokens,
+            max_tokens=max_tokens,
+            verbose=True,
+            **generate_kwargs,
+        )
     captured_output = output_buffer.getvalue()
     return parse_metrics(captured_output)


-def save_results(output_file, results: Union[Dict[str, Any], List[Dict[str, Any]]], output_format: str) -> None:
+def save_results(
+    output_file,
+    results: Union[Dict[str, Any], List[Dict[str, Any]]],
+    output_format: str,
+) -> None:
     """
     Save the benchmark results in the specified output format.
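
benchmark_performance gets hold of that verbose output by temporarily redirecting sys.stdout into an in-memory buffer while generate runs. A self-contained sketch of the same capture pattern, with a placeholder function standing in for generate(..., verbose=True):

    import contextlib
    import io

    def noisy_generate():
        # Placeholder for generate(..., verbose=True), which prints its stats.
        print("Prompt: 512 tokens, 845.3 tokens-per-sec")
        print("Generation: 256 tokens, 61.7 tokens-per-sec")

    output_buffer = io.StringIO()
    with contextlib.redirect_stdout(output_buffer):
        noisy_generate()  # nothing reaches the real stdout inside this block

    captured_output = output_buffer.getvalue()
    print(captured_output.count("tokens-per-sec"))  # 2
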
@@ -346,18 +374,31 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
         for n_prompt in args.n_prompt:
             for n_gen in args.n_gen:
                 # Warmup run
-                _ = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+                _ = benchmark_performance(
+                    model, tokenizer, n_prompt, n_gen, **args.gen_args
+                )
                 # Benchmark iterations
                 metrics_list = []
                 for i in range(args.repetitions):
-                    metrics = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+                    metrics = benchmark_performance(
+                        model, tokenizer, n_prompt, n_gen, **args.gen_args
+                    )
                     metrics_list.append(metrics)
                 # Compute average metrics
                 avg_metrics = {}
-                keys = ["prompt_tokens", "prompt_tps", "response_tokens", "response_tps", "exec_time", "ram_usage"]
+                keys = [
+                    "prompt_tokens",
+                    "prompt_tps",
+                    "response_tokens",
+                    "response_tps",
+                    "exec_time",
+                    "ram_usage",
+                ]
                 for key in keys:
                     valid_values = [m[key] for m in metrics_list if m[key] is not None]
-                    avg_metrics[key] = sum(valid_values) / len(valid_values) if valid_values else None
+                    avg_metrics[key] = (
+                        sum(valid_values) / len(valid_values) if valid_values else None
+                    )
                 result = {
                     "Model": model_path,
                     "Model Load Time (s)": round(model_load_time, 3),
@@ -367,7 +408,9 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
                     "Response TPS": round(avg_metrics["response_tps"], 3),
                     "Execution Time (s)": round(avg_metrics["exec_time"], 3),
                     "Memory Usage (GB)": (
-                        round(avg_metrics["ram_usage"], 2) if avg_metrics["ram_usage"] is not None else None
+                        round(avg_metrics["ram_usage"], 2)
+                        if avg_metrics["ram_usage"] is not None
+                        else None
                     ),
                 }
                 # Print the result row immediately after each test completes
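
Each result row is a flat dict, which is what lets the tool print it immediately and also hand the accumulated list to save_results for JSON or CSV export. The body of save_results is not part of this diff, so the following is only a hedged sketch of one way such a writer could look, reusing the csv and json modules imported at the top of the file; the format names and file handling in the real code may differ.

    import csv
    import json
    from typing import Any, Dict, List, Union

    def save_results_sketch(
        output_file: str,
        results: Union[Dict[str, Any], List[Dict[str, Any]]],
        output_format: str,
    ) -> None:
        # Normalize a single result dict into a one-row list.
        rows = results if isinstance(results, list) else [results]
        if output_format == "json":
            with open(output_file, "w") as f:
                json.dump(rows, f, indent=2)
        elif output_format == "csv":
            with open(output_file, "w", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
                writer.writeheader()
                writer.writerows(rows)
        else:
            raise ValueError(f"Unsupported output format: {output_format}")
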