From f5cd9997746195c7a5bd30cf471c33ed0b6027a8 Mon Sep 17 00:00:00 2001
From: B1ACK917
Date: Wed, 12 Mar 2025 23:10:58 +0900
Subject: [PATCH] update: pre-commit format hook

---
 llms/mlx_lm/bench.py | 75 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 59 insertions(+), 16 deletions(-)

diff --git a/llms/mlx_lm/bench.py b/llms/mlx_lm/bench.py
index d980eca0..20009c1a 100644
--- a/llms/mlx_lm/bench.py
+++ b/llms/mlx_lm/bench.py
@@ -1,3 +1,8 @@
+# @Author : Dawei Feng
+# @Time : 2025/3/12 22:00
+# @File : bench.py
+# @Email : darkv.feng@outlook.com
+
 """
 MLX-LM Benchmark Tool
 
@@ -14,19 +19,21 @@ It supports multiple input values for model, prompt tokens, and generation token
 """
 
 import argparse
+import contextlib
+import csv
+import io
 import json
 import logging
 import random
 import re
-import io
-import contextlib
-import csv
 import time
 from typing import Any, Dict, List, Optional, Union
 
-from .utils import load, generate
+from .utils import generate, load
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
+)
 
 
 class CommaSeparatedIntegers(argparse.Action):
@@ -88,7 +95,11 @@ def parse_args() -> argparse.Namespace:
         help="Outout Sequence Length (OSL). Number of tokens to generate. Accepts multiple comma-separated values.",
     )
     parser.add_argument(
-        "-r", "--repetitions", type=int, default=5, help="Number of benchmark repetitions to average results over."
+        "-r",
+        "--repetitions",
+        type=int,
+        default=5,
+        help="Number of benchmark repetitions to average results over.",
     )
     parser.add_argument(
         "-o",
@@ -168,7 +179,9 @@ def generate_synthetic_tokens(tokenizer: Any, seq_length: int) -> List[int]:
 
     # Prepend BOS token if available; otherwise, start with an empty list.
     tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
-    tokens += [random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))]
+    tokens += [
+        random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))
+    ]
     return tokens
 
 
@@ -199,13 +212,17 @@ def parse_metrics(log_output: str) -> Dict[str, Optional[float]]:
     }
 
     # Extract prompt tokens and tokens-per-second
-    prompt_match = re.search(r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
+    prompt_match = re.search(
+        r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
+    )
     if prompt_match:
         metrics["prompt_tokens"] = int(prompt_match.group(1))
         metrics["prompt_tps"] = float(prompt_match.group(2))
 
     # Extract generation tokens and tokens-per-second
-    generation_match = re.search(r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output)
+    generation_match = re.search(
+        r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
+    )
     if generation_match:
         metrics["response_tokens"] = int(generation_match.group(1))
         metrics["response_tps"] = float(generation_match.group(2))
@@ -248,12 +265,23 @@ def benchmark_performance(
     input_tokens = generate_synthetic_tokens(tokenizer, seq_length)
     output_buffer = io.StringIO()
     with contextlib.redirect_stdout(output_buffer):
-        generate(model, tokenizer, input_tokens, max_tokens=max_tokens, verbose=True, **generate_kwargs)
+        generate(
+            model,
+            tokenizer,
+            input_tokens,
+            max_tokens=max_tokens,
+            verbose=True,
+            **generate_kwargs,
+        )
     captured_output = output_buffer.getvalue()
     return parse_metrics(captured_output)
 
 
-def save_results(output_file, results: Union[Dict[str, Any], List[Dict[str, Any]]], output_format: str) -> None:
+def save_results(
+    output_file,
+    results: Union[Dict[str, Any], List[Dict[str, Any]]],
+    output_format: str,
+) -> None:
     """
     Save the benchmark results in the specified output format.
 
@@ -346,18 +374,31 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
         for n_prompt in args.n_prompt:
             for n_gen in args.n_gen:
                 # Warmup run
-                _ = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+                _ = benchmark_performance(
+                    model, tokenizer, n_prompt, n_gen, **args.gen_args
+                )
                 # Benchmark iterations
                 metrics_list = []
                 for i in range(args.repetitions):
-                    metrics = benchmark_performance(model, tokenizer, n_prompt, n_gen, **args.gen_args)
+                    metrics = benchmark_performance(
+                        model, tokenizer, n_prompt, n_gen, **args.gen_args
+                    )
                     metrics_list.append(metrics)
                 # Compute average metrics
                 avg_metrics = {}
-                keys = ["prompt_tokens", "prompt_tps", "response_tokens", "response_tps", "exec_time", "ram_usage"]
+                keys = [
+                    "prompt_tokens",
+                    "prompt_tps",
+                    "response_tokens",
+                    "response_tps",
+                    "exec_time",
+                    "ram_usage",
+                ]
                 for key in keys:
                     valid_values = [m[key] for m in metrics_list if m[key] is not None]
-                    avg_metrics[key] = sum(valid_values) / len(valid_values) if valid_values else None
+                    avg_metrics[key] = (
+                        sum(valid_values) / len(valid_values) if valid_values else None
+                    )
                 result = {
                     "Model": model_path,
                     "Model Load Time (s)": round(model_load_time, 3),
@@ -367,7 +408,9 @@ def run_benchmarks(args: argparse.Namespace) -> List[Dict[str, Any]]:
                     "Response TPS": round(avg_metrics["response_tps"], 3),
                     "Execution Time (s)": round(avg_metrics["exec_time"], 3),
                     "Memory Usage (GB)": (
-                        round(avg_metrics["ram_usage"], 2) if avg_metrics["ram_usage"] is not None else None
+                        round(avg_metrics["ram_usage"], 2)
+                        if avg_metrics["ram_usage"] is not None
+                        else None
                     ),
                 }
                 # Print the result row immediately after each test completes
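
The hunk around generate_synthetic_tokens above builds a prompt of exactly seq_length token ids: it prepends the BOS id when the tokenizer defines one and fills the remainder with random ids from the vocabulary. A minimal, self-contained sketch of that idea follows; DummyTokenizer is a made-up stand-in for the real mlx_lm tokenizer, used here only so the snippet runs on its own.

# Sketch of the synthetic-prompt construction; DummyTokenizer is illustrative, not an mlx_lm API.
import random
from typing import Any, List


class DummyTokenizer:
    vocab_size = 32000
    bos_token_id = 1


def synthetic_tokens(tokenizer: Any, seq_length: int) -> List[int]:
    vocab_size = tokenizer.vocab_size
    # Prepend BOS if the tokenizer defines one, then pad with random ids up to seq_length.
    tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
    tokens += [random.randint(0, vocab_size - 1) for _ in range(seq_length - len(tokens))]
    return tokens


print(len(synthetic_tokens(DummyTokenizer(), 128)))  # -> 128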
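
benchmark_performance (second hunk of the middle section) never reads numbers back from generate() directly; it redirects the verbose report that generate() prints into an io.StringIO buffer and lets parse_metrics recover the throughput figures with regular expressions. The sketch below reproduces that capture-and-parse pattern with only the standard library; fake_generate and its sample output lines are illustrative stand-ins, not real mlx_lm output.

# Capture-and-parse pattern used in bench.py; fake_generate and sample lines are made up.
import contextlib
import io
import re
from typing import Dict, Optional


def fake_generate() -> None:
    # Stands in for mlx_lm's generate(..., verbose=True), which prints its stats to stdout.
    print("Prompt: 128 tokens, 512.340 tokens-per-sec")
    print("Generation: 256 tokens, 45.120 tokens-per-sec")


def parse_report(log_output: str) -> Dict[str, Optional[float]]:
    metrics: Dict[str, Optional[float]] = {"prompt_tps": None, "response_tps": None}
    prompt_match = re.search(
        r"Prompt:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
    )
    if prompt_match:
        metrics["prompt_tps"] = float(prompt_match.group(2))
    generation_match = re.search(
        r"Generation:\s*(\d+)\s*tokens,\s*([\d.]+)\s*tokens-per-sec", log_output
    )
    if generation_match:
        metrics["response_tps"] = float(generation_match.group(2))
    return metrics


buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    fake_generate()  # nothing reaches the terminal; the report lands in `buffer`
print(parse_report(buffer.getvalue()))
# -> {'prompt_tps': 512.34, 'response_tps': 45.12}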
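
Finally, the averaging loop reformatted in the last hunk averages each metric only over the repetitions that actually reported it, leaving None when no repetition did. A rough standalone illustration of that behaviour, with fabricated numbers:

# Per-key, None-aware averaging as in run_benchmarks; the `runs` data is fabricated.
from typing import Dict, List, Optional

runs: List[Dict[str, Optional[float]]] = [
    {"prompt_tps": 510.0, "response_tps": 45.0, "ram_usage": None},
    {"prompt_tps": 515.0, "response_tps": 44.0, "ram_usage": 5.0},
    {"prompt_tps": None, "response_tps": 46.0, "ram_usage": 5.5},
]

avg_metrics: Dict[str, Optional[float]] = {}
for key in ["prompt_tps", "response_tps", "ram_usage"]:
    valid_values = [m[key] for m in runs if m[key] is not None]
    # Average only the repetitions that reported this metric; otherwise keep None.
    avg_metrics[key] = sum(valid_values) / len(valid_values) if valid_values else None

print(avg_metrics)
# -> {'prompt_tps': 512.5, 'response_tps': 45.0, 'ram_usage': 5.25}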