Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
update: pre-commit format hook

commit f5cd999774
parent 346d9641d7
@@ -1,3 +1,8 @@
+# @Author : Dawei Feng
+# @Time : 2025/3/12 22:00
+# @File : bench.py
+# @Email : darkv.feng@outlook.com
+
 """
 MLX-LM Benchmark Tool
 
@@ -14,19 +19,21 @@ It supports multiple input values for model, prompt tokens, and generation tokens
 """
 
 import argparse
+import contextlib
+import csv
+import io
 import json
 import logging
 import random
 import re
-import io
-import contextlib
-import csv
 import time
 from typing import Any, Dict, List, Optional, Union
 
-from .utils import load, generate
+from .utils import generate, load
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
+)
 
 
 class CommaSeparatedIntegers(argparse.Action):
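The changes above are characteristic formatter output: import reordering (isort) plus wrapping of the 89-character `logging.basicConfig` call at black's default 88-column limit. A minimal sketch reproducing both steps with the two tools' Python APIs, assuming `black` and `isort` are installed; the repository's actual pre-commit hook configuration is not shown in this commit and may differ:

# Sketch: reproduce this hunk's reformatting with the isort and black Python
# APIs (assumes `pip install black isort`; the real hook config may differ,
# e.g. in profile or line length).
import black
import isort

source = (
    "import json\n"
    "import io\n"
    "\n"
    'logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")\n'
)

source = isort.code(source, profile="black")  # reorder imports, black-compatible style
source = black.format_str(source, mode=black.Mode())  # wrap the 89-char call at 88 cols
print(source)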
@@ -88,7 +95,11 @@ def parse_args() -> argparse.Namespace:
         help="Output Sequence Length (OSL). Number of tokens to generate. Accepts multiple comma-separated values.",
     )
     parser.add_argument(
-        "-r", "--repetitions", type=int, default=5, help="Number of benchmark repetitions to average results over."
+        "-r",
+        "--repetitions",
+        type=int,
+        default=5,
+        help="Number of benchmark repetitions to average results over.",
    )
    parser.add_argument(
        "-o",
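The flags that accept multiple comma-separated values go through the `CommaSeparatedIntegers` action whose definition appears above this hunk but not in the diff. A hypothetical sketch of such an `argparse.Action`, for readers unfamiliar with custom actions; the real class may handle validation differently:

# Hypothetical sketch of a comma-separated-integers argparse action; the real
# CommaSeparatedIntegers defined earlier in bench.py is not shown in this diff.
import argparse


class CommaSeparatedIntegers(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        try:
            setattr(namespace, self.dest, [int(v) for v in values.split(",")])
        except ValueError:
            parser.error(f"{option_string} expects comma-separated integers")


parser = argparse.ArgumentParser()
parser.add_argument("-p", "--n-prompt", action=CommaSeparatedIntegers, default=[512])
print(parser.parse_args(["-p", "128,512,1024"]).n_prompt)  # [128, 512, 1024]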
@@ -168,7 +179,9 @@ def generate_synthetic_tokens(tokenizer: Any, seq_length: int) -> List[int]:
 
     # Prepend BOS token if available; otherwise, start with an empty list.
     tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
-    tokens += [random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))]
+    tokens += [
+        random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))
+    ]
 
     return tokens
 
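For reference, the synthetic-prompt logic in this hunk, runnable standalone with stand-in values (the real function reads `vocab_size` and `bos_token_id` from the tokenizer):

# Standalone version of the synthetic-prompt logic, with stand-in values for
# what bench.py reads off the tokenizer.
import random

vocab_size = 32_000  # stand-in; the real value comes from the tokenizer
bos_token_id = 1     # stand-in; may be None for tokenizers without a BOS token
seq_length = 8

tokens = [bos_token_id] if bos_token_id is not None else []
tokens += [random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))]
assert len(tokens) == seq_length  # the BOS token counts toward the requested length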
@@ -199,13 +212,17 @@ def parse_metrics(log_output: str) -> Dict[str, Optional[float]]:
     }
 
     # Extract prompt tokens and tokens-per-second
-    prompt_match = re.search(r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
+    prompt_match = re.search(
+        r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
+    )
     if prompt_match:
         metrics["prompt_tokens"] = int(prompt_match.group(1))
         metrics["prompt_tps"] = float(prompt_match.group(2))
 
     # Extract generation tokens and tokens-per-second
-    generation_match = re.search(r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
+    generation_match = re.search(
+        r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
+    )
     if generation_match:
         metrics["response_tokens"] = int(generation_match.group(1))
         metrics["response_tps"] = float(generation_match.group(2))
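The two patterns above match the timing summary that `generate(..., verbose=True)` prints. A quick sanity check against a fabricated sample of that output (the numbers are invented; the line format is inferred from the regexes themselves):

# Sanity check of the parse_metrics regexes against a fabricated sample of the
# verbose output they target.
import re

sample = (
    "Prompt: 512 tokens, 845.3 tokens-per-sec\n"
    "Generation: 128 tokens, 52.7 tokens-per-sec\n"
)

prompt = re.search(r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", sample)
gen = re.search(r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", sample)
print(int(prompt.group(1)), float(prompt.group(2)))  # 512 845.3
print(int(gen.group(1)), float(gen.group(2)))        # 128 52.7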
@@ -248,12 +265,23 @@ def benchmark_performance(
     input_tokens = generate_synthetic_tokens(tokenizer, seq_length)
     output_buffer = io.StringIO()
     with contextlib.redirect_stdout(output_buffer):
-        generate(model, tokenizer, input_tokens, max_tokens=max_tokens, verbose=True, **generate_kwargs)
+        generate(
+            model,
+            tokenizer,
+            input_tokens,
+            max_tokens=max_tokens,
+            verbose=True,
+            **generate_kwargs,
+        )
     captured_output = output_buffer.getvalue()
     return parse_metrics(captured_output)
 
 
-def save_results(output_file, results: Union[Dict[str, Any], List[Dict[str, Any]]], output_format: str) -> None:
+def save_results(
+    output_file,
+    results: Union[Dict[str, Any], List[Dict[str, Any]]],
+    output_format: str,
+) -> None:
     """
     Save the benchmark results in the specified output format.
 
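`benchmark_performance` gets its numbers by redirecting the verbose generator's stdout into a buffer and parsing it afterwards. The capture pattern in isolation, with a plain `print` standing in for `generate`:

# The stdout-capture pattern from benchmark_performance in isolation; print()
# stands in for mlx_lm's generate(..., verbose=True).
import contextlib
import io

buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    print("Prompt: 512 tokens, 845.3 tokens-per-sec")  # fabricated stats line

captured = buffer.getvalue()  # what parse_metrics would receive
print("captured:", captured.strip())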
@@ -346,18 +374,31 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
         for n_prompt in args.n_prompt:
             for n_gen in args.n_gen:
                 # Warmup run
-                _ = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+                _ = benchmark_performance(
+                    model, tokenizer, n_prompt, n_gen, **args.gen_args
+                )
                 # Benchmark iterations
                 metrics_list = []
                 for i in range(args.repetitions):
-                    metrics = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+                    metrics = benchmark_performance(
+                        model, tokenizer, n_prompt, n_gen, **args.gen_args
+                    )
                     metrics_list.append(metrics)
                 # Compute average metrics
                 avg_metrics = {}
-                keys = ["prompt_tokens", "prompt_tps", "response_tokens", "response_tps", "exec_time", "ram_usage"]
+                keys = [
+                    "prompt_tokens",
+                    "prompt_tps",
+                    "response_tokens",
+                    "response_tps",
+                    "exec_time",
+                    "ram_usage",
+                ]
                 for key in keys:
                     valid_values = [m[key] for m in metrics_list if m[key] is not None]
-                    avg_metrics[key] = sum(valid_values) / len(valid_values) if valid_values else None
+                    avg_metrics[key] = (
+                        sum(valid_values) / len(valid_values) if valid_values else None
+                    )
                 result = {
                     "Model": model_path,
                     "Model Load Time (s)": round(model_load_time, 3),
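The averaging loop skips `None` entries per metric, so a field missing from some runs (e.g. `ram_usage`) does not poison the mean. The same logic in isolation, with fabricated per-run metrics:

# The None-safe averaging from run_benchmarks in isolation; the per-run
# metric values are fabricated.
metrics_list = [
    {"prompt_tps": 840.0, "ram_usage": None},
    {"prompt_tps": 850.0, "ram_usage": None},
]

avg_metrics = {}
for key in ["prompt_tps", "ram_usage"]:
    valid_values = [m[key] for m in metrics_list if m[key] is not None]
    avg_metrics[key] = sum(valid_values) / len(valid_values) if valid_values else None

print(avg_metrics)  # {'prompt_tps': 845.0, 'ram_usage': None}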
@@ -367,7 +408,9 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
                     "Response TPS": round(avg_metrics["response_tps"], 3),
                     "Execution Time (s)": round(avg_metrics["exec_time"], 3),
                     "Memory Usage (GB)": (
-                        round(avg_metrics["ram_usage"], 2) if avg_metrics["ram_usage"] is not None else None
+                        round(avg_metrics["ram_usage"], 2)
+                        if avg_metrics["ram_usage"] is not None
+                        else None
                     ),
                 }
                 # Print the result row immediately after each test completes