Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
update: pre-commit format hook
parent 346d9641d7
commit f5cd999774
@@ -1,3 +1,8 @@
+# @Author : Dawei Feng
+# @Time : 2025/3/12 22:00
+# @File : bench.py
+# @Email : darkv.feng@outlook.com
+
 """
 MLX-LM Benchmark Tool

@@ -14,19 +19,21 @@ It supports multiple input values for model, prompt tokens, and generation token
 """

 import argparse
+import contextlib
+import csv
+import io
 import json
 import logging
 import random
 import re
-import io
-import contextlib
-import csv
 import time
 from typing import Any, Dict, List, Optional, Union

-from .utils import load, generate
+from .utils import generate, load

-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
+)


 class CommaSeparatedIntegers(argparse.Action):
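Note: the hunks in this commit are mechanical reflows of long lines plus import reordering, the kind of change a black/isort-style pre-commit hook produces. A hedged, self-contained way to sanity-check that a pure line-wrapping change preserves semantics is to compare the ASTs of the two layouts; the before/after strings below are made up to mirror the logging call above (reordered imports would, correctly, still show up as a difference, since statement order changes).

# Sketch: check that reflowing a statement does not change its AST.
# The before/after strings are illustrative, mirroring the logging call above.
import ast

before = (
    'logging.basicConfig(level=logging.INFO, '
    'format="%(asctime)s [%(levelname)s] %(message)s")'
)
after = (
    "logging.basicConfig(\n"
    '    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"\n'
    ")"
)

# ast.dump ignores layout and, by default, omits line/column attributes,
# so two sources that differ only in formatting dump to identical strings.
print(ast.dump(ast.parse(before)) == ast.dump(ast.parse(after)))  # True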
@@ -88,7 +95,11 @@ def parse_args() -> argparse.Namespace:
         help="Outout Sequence Length (OSL). Number of tokens to generate. Accepts multiple comma-separated values.",
     )
     parser.add_argument(
-        "-r", "--repetitions", type=int, default=5, help="Number of benchmark repetitions to average results over."
+        "-r",
+        "--repetitions",
+        type=int,
+        default=5,
+        help="Number of benchmark repetitions to average results over.",
     )
     parser.add_argument(
         "-o",
@@ -168,7 +179,9 @@ def generate_synthetic_tokens(tokenizer: Any, seq_length: int) -> List[int]:

     # Prepend BOS token if available; otherwise, start with an empty list.
     tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
-    tokens += [random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))]
+    tokens += [
+        random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))
+    ]

     return tokens

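Note: the reflowed comprehension builds a prompt of exactly seq_length token IDs: an optional BOS token followed by uniformly random IDs below the vocabulary size (vocab_size is computed earlier in the function, outside this hunk). A standalone sketch of the same idea, using a dummy tokenizer stand-in rather than the real tokenizer returned by load():

# Sketch of the synthetic-prompt construction; DummyTokenizer is a stand-in,
# not the tokenizer class that mlx_lm's load() actually returns.
import random
from typing import List, Optional


class DummyTokenizer:
    bos_token_id: Optional[int] = 1  # set to None to exercise the other branch


def synthetic_tokens(
    tokenizer: DummyTokenizer, seq_length: int, vocab_size: int = 32000
) -> List[int]:
    tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
    tokens += [
        random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))
    ]
    return tokens


print(len(synthetic_tokens(DummyTokenizer(), 128)))  # always exactly 128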
@@ -199,13 +212,17 @@ def parse_metrics(log_output: str) -> Dict[str, Optional[float]]:
     }

     # Extract prompt tokens and tokens-per-second
-    prompt_match = re.search(r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
+    prompt_match = re.search(
+        r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
+    )
     if prompt_match:
         metrics["prompt_tokens"] = int(prompt_match.group(1))
         metrics["prompt_tps"] = float(prompt_match.group(2))

     # Extract generation tokens and tokens-per-second
-    generation_match = re.search(r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
+    generation_match = re.search(
+        r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
+    )
     if generation_match:
         metrics["response_tokens"] = int(generation_match.group(1))
         metrics["response_tps"] = float(generation_match.group(2))
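Note: both re.search calls target the summary lines that generate(..., verbose=True) prints to stdout. A minimal sketch of the prompt-side pattern against an illustrative sample line (the numbers are made up):

# Sketch: applying the prompt regex from parse_metrics to a sample line.
# The sample string is illustrative, not real benchmark output.
import re

sample = "Prompt: 512 tokens, 843.217 tokens-per-sec"
prompt_match = re.search(
    r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", sample
)
if prompt_match:
    print(int(prompt_match.group(1)))    # 512
    print(float(prompt_match.group(2)))  # 843.217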
@@ -248,12 +265,23 @@ def benchmark_performance(
     input_tokens = generate_synthetic_tokens(tokenizer, seq_length)
     output_buffer = io.StringIO()
     with contextlib.redirect_stdout(output_buffer):
-        generate(model, tokenizer, input_tokens, max_tokens=max_tokens, verbose=True, **generate_kwargs)
+        generate(
+            model,
+            tokenizer,
+            input_tokens,
+            max_tokens=max_tokens,
+            verbose=True,
+            **generate_kwargs,
+        )
     captured_output = output_buffer.getvalue()
     return parse_metrics(captured_output)


-def save_results(output_file, results: Union[Dict[str, Any], List[Dict[str, Any]]], output_format: str) -> None:
+def save_results(
+    output_file,
+    results: Union[Dict[str, Any], List[Dict[str, Any]]],
+    output_format: str,
+) -> None:
     """
     Save the benchmark results in the specified output format.

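Note: benchmark_performance captures everything generate prints by temporarily routing stdout into an in-memory buffer; the reformatting only changes how the call is laid out. The capture pattern itself is standard library only; a self-contained sketch with a stand-in for generate:

# Sketch of the stdout-capture pattern used by benchmark_performance:
# redirect stdout into an io.StringIO buffer, then read it back for parsing.
import contextlib
import io


def noisy_generate() -> None:
    # Stand-in for mlx_lm's generate(..., verbose=True), which prints metrics.
    print("Prompt: 512 tokens, 843.217 tokens-per-sec")
    print("Generation: 128 tokens, 96.401 tokens-per-sec")


output_buffer = io.StringIO()
with contextlib.redirect_stdout(output_buffer):
    noisy_generate()
captured_output = output_buffer.getvalue()
print(captured_output)  # printed after stdout has been restored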
@@ -346,18 +374,31 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
         for n_prompt in args.n_prompt:
             for n_gen in args.n_gen:
                 # Warmup run
-                _ = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+                _ = benchmark_performance(
+                    model, tokenizer, n_prompt, n_gen, **args.gen_args
+                )
                 # Benchmark iterations
                 metrics_list = []
                 for i in range(args.repetitions):
-                    metrics = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+                    metrics = benchmark_performance(
+                        model, tokenizer, n_prompt, n_gen, **args.gen_args
+                    )
                     metrics_list.append(metrics)
                 # Compute average metrics
                 avg_metrics = {}
-                keys = ["prompt_tokens", "prompt_tps", "response_tokens", "response_tps", "exec_time", "ram_usage"]
+                keys = [
+                    "prompt_tokens",
+                    "prompt_tps",
+                    "response_tokens",
+                    "response_tps",
+                    "exec_time",
+                    "ram_usage",
+                ]
                 for key in keys:
                     valid_values = [m[key] for m in metrics_list if m[key] is not None]
-                    avg_metrics[key] = sum(valid_values) / len(valid_values) if valid_values else None
+                    avg_metrics[key] = (
+                        sum(valid_values) / len(valid_values) if valid_values else None
+                    )
                 result = {
                     "Model": model_path,
                     "Model Load Time (s)": round(model_load_time, 3),
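Note: the reflowed conditional keeps the None-aware averaging intact: metrics that failed to parse are dropped per key, and the mean is taken over the remaining samples only. A small sketch with made-up values:

# Sketch of the None-aware averaging used in run_benchmarks.
# The metric values below are made up for illustration.
metrics_list = [
    {"prompt_tps": 840.0, "ram_usage": None},
    {"prompt_tps": 860.0, "ram_usage": 5.2},
]

avg_metrics = {}
for key in ["prompt_tps", "ram_usage"]:
    valid_values = [m[key] for m in metrics_list if m[key] is not None]
    avg_metrics[key] = (
        sum(valid_values) / len(valid_values) if valid_values else None
    )

print(avg_metrics)  # {'prompt_tps': 850.0, 'ram_usage': 5.2}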
@@ -367,7 +408,9 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
                     "Response TPS": round(avg_metrics["response_tps"], 3),
                     "Execution Time (s)": round(avg_metrics["exec_time"], 3),
                     "Memory Usage (GB)": (
-                        round(avg_metrics["ram_usage"], 2) if avg_metrics["ram_usage"] is not None else None
+                        round(avg_metrics["ram_usage"], 2)
+                        if avg_metrics["ram_usage"] is not None
+                        else None
                     ),
                 }
                 # Print the result row immediately after each test completes