diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py
index 538be927..2e088d63 100644
--- a/llms/mlx_lm/__init__.py
+++ b/llms/mlx_lm/__init__.py
@@ -7,3 +7,9 @@ from ._version import __version__
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
 
 from .utils import convert, generate, load, stream_generate
+
+
+def get_estimate_mem():
+    from .estimate_memory import estimate_mem
+
+    return estimate_mem
diff --git a/llms/mlx_lm/estimate_memory.py b/llms/mlx_lm/estimate_memory.py
new file mode 100644
index 00000000..e66ff8f0
--- /dev/null
+++ b/llms/mlx_lm/estimate_memory.py
@@ -0,0 +1,470 @@
+# Copyright © 2023-2025 Apple Inc.
+
+import argparse
+import json
+import math
+import os
+from pathlib import Path
+from typing import Dict, Optional, Tuple, Type
+
+from huggingface_hub import hf_hub_download, try_to_load_from_cache
+from mlx_lm.models.base import BaseModelArgs
+from mlx_lm.utils import _get_classes
+
+
+def fetch_metadata(model_path: str) -> Tuple[Dict, Optional[int]]:
+    """Fetch config.json and, if available, model.safetensors.index.json for the weights size."""
+    config = fetch_config(model_path)
+    model_weight_size = None
+    if not os.path.isdir(model_path):
+        try:
+            index_path = hf_hub_download(
+                repo_id=model_path, filename="model.safetensors.index.json"
+            )
+            with open(index_path, "r") as f:
+                index = json.load(f)
+            model_weight_size = index.get("metadata", {}).get("total_size")
+        except Exception:
+            pass
+    return config, model_weight_size
+
+
+def fetch_config(model_path: str) -> Dict:
+    """Fetch or load config.json without downloading the full model, checking the cache first."""
+    if os.path.isdir(model_path):
+        config_path = Path(model_path) / "config.json"
+        if not config_path.exists():
+            raise FileNotFoundError(f"config.json not found in {model_path}")
+        with open(config_path, "r") as f:
+            return json.load(f)
+    else:
+        cached_path = try_to_load_from_cache(
+            repo_id=model_path, filename="config.json", repo_type="model"
+        )
+        if cached_path and os.path.exists(cached_path):
+            with open(cached_path, "r") as f:
+                return json.load(f)
+        try:
+            config_path = hf_hub_download(repo_id=model_path, filename="config.json")
+            with open(config_path, "r") as f:
+                return json.load(f)
+        except Exception as e:
+            raise ValueError(f"Failed to fetch config.json from {model_path}: {str(e)}")
+
+
+def compute_bits_per_weight_from_config(config: Dict) -> float:
+    """Infer bits-per-weight from the config, defaulting to 16 (FP16) if unquantized."""
+    quantization = config.get("quantization", {})
+    bits = quantization.get("bits")
+    return float(bits) if bits is not None else 16.0
+
+
+def calc_embedding_params(vocab_size: int, hidden_size: int) -> int:
+    """Calculate parameters for the embedding layer."""
+    return vocab_size * hidden_size
+
+
+def calc_attention_params(args, hidden_size: int, num_attention_heads: int) -> int:
+    """Calculate parameters for one attention layer, handling standard and LoRA variants.
+
+    This function supports both standard multi-head attention (e.g., Mixtral, OLMoE) and
+    LoRA-like attention (e.g., DeepSeek V3) by checking for q_lora_rank, which is more
+    flexible than the older hardcoded approach that assumed uniform QKV dimensions.
+    """
+    num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
+    head_dim = getattr(args, "head_dim", None)
+    if head_dim is None:
+        head_dim = hidden_size // num_attention_heads
+    has_bias = getattr(args, "attention_bias", False)
+
+    # Standard attention (Q, K, V, O)
+    if not hasattr(args, "q_lora_rank") or not args.q_lora_rank:
+        q_params = hidden_size * (num_attention_heads * head_dim)
+        k_params = hidden_size * (num_kv_heads * head_dim)
+        v_params = k_params
+        o_params = (num_attention_heads * head_dim) * hidden_size
+    # LoRA-like attention (e.g., DeepSeek V3)
+    else:
+        q_head_dim = getattr(args, "qk_nope_head_dim", 0) + getattr(
+            args, "qk_rope_head_dim", head_dim
+        )
+        v_head_dim = getattr(args, "v_head_dim", head_dim)
+        qk_rope_dim = getattr(args, "qk_rope_head_dim", head_dim)
+        q_params = hidden_size * args.q_lora_rank + args.q_lora_rank * (
+            num_attention_heads * q_head_dim
+        )
+        k_params = hidden_size * (args.kv_lora_rank + qk_rope_dim)
+        v_params = args.kv_lora_rank * (
+            num_attention_heads * (q_head_dim - qk_rope_dim + v_head_dim)
+        )
+        o_params = (num_attention_heads * v_head_dim) * hidden_size
+
+    total = q_params + k_params + v_params + o_params
+    if has_bias:
+        total += (
+            num_attention_heads * head_dim + num_kv_heads * head_dim * 2 + hidden_size
+        )
+
+    return total
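+
+
+# Illustrative sanity check for calc_attention_params with hypothetical Llama-style GQA
+# dimensions (not taken from any particular config): hidden_size=4096, 32 query heads,
+# 8 KV heads, head_dim=128, no attention bias:
+#   Q: 4096 * (32 * 128) = 16,777,216
+#   K: 4096 * (8 * 128)  =  4,194,304
+#   V: same as K         =  4,194,304
+#   O: (32 * 128) * 4096 = 16,777,216
+# giving roughly 41.9M parameters per attention block.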
+
+
+def calc_ffn_or_moe_params(
+    args, hidden_size: int, intermediate_size: int, layer_idx: int
+) -> int:
+    """Calculate parameters for FFN or MoE layer, switching based on config.
+
+    Unlike the previous hardcoded FFN-only calculation, this function dynamically handles
+    Mixture of Experts (MoE) models like Mixtral, OLMoE, and DeepSeek V3 by detecting
+    expert-related fields (num_experts, n_routed_experts) and adjusting for dense vs.
+    MoE layers, supporting varied intermediate sizes and shared experts.
+    """
+    num_experts = max(
+        getattr(args, "num_local_experts", 0),
+        getattr(args, "num_experts", 0),
+        getattr(args, "n_routed_experts", 0),
+    )
+    moe_intermediate_size = getattr(args, "moe_intermediate_size", intermediate_size)
+    dense_up_to = (
+        getattr(args, "first_k_dense_replace", 0)
+        if num_experts
+        else args.num_hidden_layers
+    )
+    has_bias = getattr(args, "mlp_bias", False)
+    shared_experts = getattr(args, "n_shared_experts", 0)
+
+    if num_experts and layer_idx >= dense_up_to:
+        # MoE: gate + expert FFNs
+        gate_params = hidden_size * num_experts
+        expert_params = (
+            num_experts * hidden_size * moe_intermediate_size * 3
+        )  # gate_proj, up_proj, down_proj
+        shared_params = (
+            shared_experts * hidden_size * moe_intermediate_size * 3
+            if shared_experts
+            else 0
+        )
+        return gate_params + expert_params + shared_params
+    else:
+        # Dense FFN
+        ffn_params = (
+            hidden_size * intermediate_size * 2 + intermediate_size * hidden_size
+        )
+        if has_bias:
+            ffn_params += intermediate_size * 2 + hidden_size
+        return ffn_params
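+
+
+# Illustrative check of calc_ffn_or_moe_params with hypothetical Mixtral-like dimensions:
+# hidden_size=4096, 8 routed experts, moe_intermediate_size=14336, no shared experts:
+#   router gate: 4096 * 8             =        32,768
+#   experts:     8 * 4096 * 14336 * 3 = 1,409,286,144
+# i.e. roughly 1.41B parameters per MoE layer, versus about 176M
+# (4096 * 14336 * 3) for a dense FFN with the same intermediate size.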
+
+
+def calc_norm_params(args, hidden_size: int, num_attention_heads: int) -> int:
+    """Calculate normalization parameters, adjusting for extra norms in complex models.
+
+    This extends the old approach (fixed 2 norms per layer) by adding heuristic support
+    for extra normalization layers (e.g., OLMoE's q_norm, k_norm) in MoE or LoRA models,
+    improving accuracy over the simpler assumption of uniform RMSNorm usage.
+    """
+    num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
+    head_dim = getattr(args, "head_dim", None)
+    if head_dim is None:
+        head_dim = hidden_size // num_attention_heads
+
+    # Base: input + post-attention RMSNorm
+    total = hidden_size * 2
+
+    # Heuristic: extra norms for MoE or complex attention
+    has_experts = any(
+        getattr(args, attr, 0) > 0
+        for attr in ["num_local_experts", "num_experts", "n_routed_experts"]
+    )
+    if has_experts or hasattr(args, "q_lora_rank"):
+        total += (num_attention_heads * head_dim) + (num_kv_heads * head_dim)
+
+    return total
+
+
+def calculate_num_parameters(
+    config: Dict, model_args_class: Optional[Type["BaseModelArgs"]] = None
+) -> int:
+    """Calculate the total number of parameters in a model based on its config.
+
+    By splitting into separate functions, we now support diverse
+    architectures while maintaining readability and avoiding model-specific hardcoding.
+    """
+    # Use the imported _get_classes function to get the ModelArgs class
+    if model_args_class is None:
+        _, model_args_class = _get_classes(config)
+
+    args = model_args_class.from_dict(config)
+
+    # Validate required fields
+    required = [
+        "hidden_size",
+        "num_hidden_layers",
+        "vocab_size",
+        "num_attention_heads",
+        "intermediate_size",
+    ]
+    missing = [field for field in required if getattr(args, field, None) is None]
+    if missing:
+        raise ValueError(f"Config missing required fields: {missing}")
+
+    # Core config
+    hidden_size = args.hidden_size
+    num_layers = args.num_hidden_layers
+    vocab_size = args.vocab_size
+    num_attention_heads = args.num_attention_heads
+    intermediate_size = args.intermediate_size
+
+    # Total calculation
+    total_params = calc_embedding_params(vocab_size, hidden_size)
+    for layer in range(num_layers):
+        total_params += calc_attention_params(args, hidden_size, num_attention_heads)
+        total_params += calc_ffn_or_moe_params(
+            args, hidden_size, intermediate_size, layer
+        )
+        total_params += calc_norm_params(args, hidden_size, num_attention_heads)
+    total_params += hidden_size  # Final norm
+    if not getattr(args, "tie_word_embeddings", True):
+        total_params += hidden_size * vocab_size  # LM head
+
+    return total_params
+
+
+def calculate_head_dim(config: Dict, args: BaseModelArgs) -> int:
+    """Infer head dimension dynamically from config or args."""
+    head_dim = getattr(args, "head_dim", None)
+    if head_dim is None:
+        if "hidden_size" not in config or "num_attention_heads" not in config:
+            raise ValueError(
+                "Cannot compute head_dim: missing hidden_size or num_attention_heads"
+            )
+        head_dim = config["hidden_size"] // config["num_attention_heads"]
+    return head_dim
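+
+
+# A minimal cross-check sketch (assumes mlx and mlx_lm are installed; the local path is
+# hypothetical): compare the analytic count against the materialized parameters of an
+# unquantized model.
+#
+#     from mlx.utils import tree_flatten
+#     from mlx_lm import load
+#
+#     config = fetch_config("/path/to/local/model")   # hypothetical path
+#     analytic = calculate_num_parameters(config)
+#     model, _ = load("/path/to/local/model")
+#     actual = sum(v.size for _, v in tree_flatten(model.parameters()))
+#     print(f"analytic={analytic:,} actual={actual:,}")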
+
+
+def estimate_mem(
+    model_path: str,
+    context_length: int = 4096,
+    max_kv_size: Optional[int] = None,
+    kv_bits: Optional[int] = None,
+    kv_group_size: Optional[int] = None,
+    tokens_to_generate: int = 0,
+) -> Tuple[Dict[str, float], str]:
+    """
+    Estimate the memory usage of a model.
+
+    Args:
+        model_path: Path to the model.
+        context_length: Context length of the model (prompt length in unbounded mode).
+        max_kv_size: Maximum size of the KV cache (for bounded mode).
+        kv_bits: Number of bits to use for quantized KV cache.
+        kv_group_size: Group size to use for quantized KV cache.
+        tokens_to_generate: Number of tokens to generate beyond the prompt.
+
+    Returns:
+        A tuple of (results, mode) where results is a dictionary of memory usage
+        and mode is a string indicating the mode of the KV cache.
+    """
+    config, model_weight_size = fetch_metadata(model_path)
+    bits_per_weight = compute_bits_per_weight_from_config(config)
+
+    # Determine the model class
+    _, model_args_class = _get_classes(config)
+    args = model_args_class.from_dict(config)
+
+    # Calculate the number of parameters
+    num_parameters = calculate_num_parameters(config, model_args_class)
+
+    # Extract model architecture parameters needed for memory calculations
+    num_layers = args.num_hidden_layers
+    num_kv_heads = getattr(args, "num_key_value_heads", args.num_attention_heads)
+    head_dim = calculate_head_dim(config, args)
+
+    # Default to fp16 (2 bytes per element) for KV cache unless quantized
+    bytes_per_element = 2
+
+    # If kv_bits and kv_group_size are not provided, try to read from config
+    if kv_bits is None or kv_group_size is None:
+        # Try to get quantization settings from config
+        quantization = config.get("quantization", {})
+        quantization_config = config.get("quantization_config", {})
+
+        # Use the first available quantization config
+        quant_info = quantization or quantization_config
+
+        if quant_info:
+            kv_bits = kv_bits or quant_info.get("bits")
+            kv_group_size = kv_group_size or quant_info.get("group_size")
+
+    # Calculate the model weight memory usage
+    bytes_per_parameter = bits_per_weight / 8
+    if model_weight_size:
+        # Use the size from safetensors index if available
+        model_size_gb = model_weight_size / (1024**3)
+    else:
+        # Calculate from parameter count
+        model_size_gb = (num_parameters * bytes_per_parameter) / (1024**3)
+
+    # Estimate tokenizer size
+    vocab_size = config.get("vocab_size", args.vocab_size)
+    fixed_overhead_bytes = 25 * 1024 * 1024
+    avg_token_size_bytes = 650
+    tokenizer_size_bytes = (vocab_size * avg_token_size_bytes) + fixed_overhead_bytes
+    tokenizer_size_gb = tokenizer_size_bytes / (1024**3)
+
+    # Determine the mode
+    mode = "Bounded" if max_kv_size else "Unbounded"
+
+    # KV length is fixed to max_kv_size in bounded mode, or context_length in unbounded mode
+    kv_length = max_kv_size if mode == "Bounded" else context_length
+
+    # Default step size from cache.py is 256
+    step_size = 256
+    kv_length_padded = ((kv_length + step_size - 1) // step_size) * step_size
+
+    # Calculate KV cache size based on whether quantization is used
+    if kv_bits and kv_group_size:
+        # Quantized cache calculations
+        groups_per_head_dim = (head_dim + kv_group_size - 1) // kv_group_size
+        elements_per_int = 8 * 4 // kv_bits
+
+        data_size = (
+            num_kv_heads * kv_length_padded * (head_dim // elements_per_int) * 4
+        ) / (1024**3)
+        quant_overhead = (
+            num_kv_heads * kv_length_padded * groups_per_head_dim * 2 * 2
+        ) / (1024**3)
+        per_layer_kv_size = 2 * (data_size + quant_overhead)
+
+        elements_per_token = (head_dim // elements_per_int) * 4
+        scales_zeros_per_token = groups_per_head_dim * 2 * 2
+        per_token_bytes = (
+            2 * num_kv_heads * (elements_per_token + scales_zeros_per_token)
+        )
+        per_token_increase = (per_token_bytes * num_layers) / (1024**3)
+    else:
+        # Standard fp16 cache
+        per_layer_kv_size = (
+            2 * num_kv_heads * kv_length_padded * head_dim * bytes_per_element
+        ) / (1024**3)
+        per_token_increase = (
+            2 * num_kv_heads * head_dim * bytes_per_element * num_layers
+        ) / (1024**3)
+
+    total_kv_cache_size = num_layers * per_layer_kv_size
+
+    # Add the memory for generated tokens if specified
+    if tokens_to_generate > 0:
+        total_kv_cache_size += tokens_to_generate * per_token_increase
+
+    # For inference in MLX, estimate activation memory
+    activation_size_gb = 0.03 * model_size_gb
+
+    overhead_gb = tokenizer_size_gb + activation_size_gb + (model_size_gb * 0.05)
+
+    # Total memory usage
+    total_memory_gb = model_size_gb + total_kv_cache_size + overhead_gb
+
+    results = {
+        "Weight": model_size_gb,
+        "KV Cache": total_kv_cache_size,
+        "Overhead": overhead_gb,
+        "Total": total_memory_gb,
+        "per_token_increase": per_token_increase,
+    }
+
+    return results, mode
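+
+
+# Example usage sketch (the repo id below is hypothetical; any local path or Hugging Face
+# repo with a config.json should work): estimate memory with an 8-bit quantized KV cache
+# and print the summary table produced by print_table below.
+#
+#     results, mode = estimate_mem(
+#         "mlx-community/SomeModel-4bit",   # hypothetical repo id
+#         context_length=8192,
+#         kv_bits=8,
+#         kv_group_size=64,
+#         tokens_to_generate=512,
+#     )
+#     print_table(results, mode, tokens_to_generate=512)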
+
+
+def setup_arg_parser():
+    parser = argparse.ArgumentParser(description="MLX Model Memory Estimation Tool")
+    parser.add_argument("model", help="Local model path or Hugging Face repo ID.")
+    parser.add_argument(
+        "--context-length",
+        type=int,
+        default=4096,
+        help="Context length of the model (prompt length in unbounded mode).",
+    )
+    parser.add_argument(
+        "--max-kv-size", type=int, help="Max KV cache size (enables bounded mode)."
+    )
+    parser.add_argument("--kv-bits", type=int, help="Bits for KV cache quantization.")
+    parser.add_argument(
+        "--kv-group-size", type=int, help="Group size for KV cache quantization."
+    )
+    parser.add_argument(
+        "--tokens-to-generate",
+        type=int,
+        default=0,
+        help="Number of tokens to generate beyond the prompt.",
+    )
+    return parser
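+
+
+# Command-line sketch (assumes the mlx_lm package is importable so the module can be run
+# directly; the repo id is hypothetical):
+#
+#     python -m mlx_lm.estimate_memory mlx-community/SomeModel-4bit \
+#         --max-kv-size 1024 --kv-bits 8 --kv-group-size 64 --tokens-to-generate 256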
+
+
+def print_table(
+    results: Dict[str, float], mode: str, tokens_to_generate: int = 0
+) -> None:
+    """
+    Print a memory usage table in a formatted way.
+
+    Args:
+        results: Dictionary containing memory usage data (in GB unless specified).
+        mode: Either "Bounded" or "Unbounded" to describe the KV Cache type.
+        tokens_to_generate: Number of tokens generated (optional, defaults to 0).
+    """
+    # Construct the title dynamically
+    title = f"*Memory Usage Estimate ({mode} KV Cache"
+    if tokens_to_generate > 0:
+        title += f" after generating {tokens_to_generate:,} tokens"
+    title += "):*"
+
+    # Define table formatting constants
+    LINE_WIDTH = 34
+    ITEM_COL_WIDTH = 17
+    MEMORY_COL_WIDTH = 12
+
+    # Print header
+    print(title)
+    print("-" * LINE_WIDTH)
+    print(f"| {'Item':<{ITEM_COL_WIDTH}} | {'Memory':<{MEMORY_COL_WIDTH}} |")
+    print("-" * LINE_WIDTH)
+
+    # Define display order and handle missing keys gracefully
+    display_order = ["Weight", "KV Cache", "Overhead", "Total"]
+    for key in display_order:
+        value = results.get(key, 0.0)  # Default to 0.0 if key is missing
+        memory_str = f"{value:.2f} GB"
+
+        # Print a separator before the Total row, then the row itself
+        if key == "Total":
+            print("-" * LINE_WIDTH)
+        print(f"| {key:<{ITEM_COL_WIDTH}} | {memory_str:>{MEMORY_COL_WIDTH}} |")
+
+    print("-" * LINE_WIDTH)
+
+    # Add footer for Unbounded mode
+    if mode == "Unbounded" and "per_token_increase" in results:
+        per_token_gb = results["per_token_increase"]
+        if per_token_gb > 0:  # Avoid division by zero
+            tokens_per_gb = math.floor(1 / per_token_gb)
+            print(f"Additional tokens per 1GB increase: {tokens_per_gb:,}")
+        else:
+            print("Note: Per-token increase is zero or invalid.")
+
+
+def main():
+    parser = setup_arg_parser()
+    args = parser.parse_args()
+
+    results, mode = estimate_mem(
+        args.model,
+        args.context_length,
+        args.max_kv_size,
+        args.kv_bits,
+        args.kv_group_size,
+        args.tokens_to_generate,
+    )
+
+    print_table(results, mode, args.tokens_to_generate)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llms/test_memory_estimation.py b/llms/test_memory_estimation.py
new file mode 100644
index 00000000..79f44fb7
--- /dev/null
+++ b/llms/test_memory_estimation.py
@@ -0,0 +1,369 @@
+#!/usr/bin/env python3
+# Copyright © 2023-2025 Apple Inc.
+
+"""
+Test script to validate memory usage estimation from estimate_memory.py.
+Loads a model, runs inference, and compares actual vs estimated memory usage.
+"""
+
+import argparse
+import gc
+import os
+import time
+from typing import Dict, Optional, Tuple
+
+import mlx.core as mx
+from mlx_lm.estimate_memory import (
+    calculate_head_dim,
+    compute_bits_per_weight_from_config,
+    estimate_mem,
+    fetch_config,
+)
+from mlx_lm.utils import generate, load, stream_generate
+
+
+def setup_arg_parser():
+    parser = argparse.ArgumentParser(description="Test Memory Estimation Accuracy")
+    parser.add_argument("model", help="Local model path or Hugging Face repo ID.")
+    parser.add_argument(
+        "--prompt", type=str, default="Once upon a time", help="Prompt for inference."
+    )
+    parser.add_argument(
+        "--prompt-file",
+        type=str,
+        default="prompt.txt",
+        help="File containing the prompt.",
+    )
+    parser.add_argument(
+        "--num-tokens", type=int, default=50, help="Number of tokens to generate."
+    )
+    parser.add_argument("--kv-bits", type=int, help="Bits for KV cache quantization.")
+    parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        default=64,
+        help="Group size for KV cache quantization.",
+    )
+    parser.add_argument(
+        "--max-kv-size", type=int, help="Max KV cache size (bounded mode)."
+    )
+    parser.add_argument(
+        "--context-length",
+        type=int,
+        default=4096,
+        help="Context length for estimation.",
+    )
+    parser.add_argument(
+        "--verbose", action="store_true", help="Enable verbose logging."
+    )
+    return parser
+
+
+def get_memory_usage():
+    """Return peak memory use in GB since the last reset, then reset the peak counter."""
+    used_gb = mx.metal.get_peak_memory() / (1024**3)
+    mx.metal.reset_peak_memory()
+    return used_gb
+
+
+def get_active_memory_gb():
+    """Get current active memory usage in GB."""
+    return mx.metal.get_active_memory() / (1024**3)
+
+
+def force_gc_and_reset():
+    """Force garbage collection and reset memory counters."""
+    gc.collect()
+    mx.metal.reset_peak_memory()
+    time.sleep(0.1)  # Small delay to ensure memory operations complete
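+
+
+# The helpers above follow a bracket-style measurement pattern: active memory approximates
+# the buffers MLX currently holds, while peak memory is a high-water mark since the last
+# reset. A minimal sketch of how they are combined (illustrative only):
+#
+#     force_gc_and_reset()
+#     before_gb = get_active_memory_gb()
+#     # ... run the workload being measured (e.g., a forward pass) ...
+#     delta_gb = get_active_memory_gb() - before_gb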
+
+
+def measure_kv_cache_memory(
+    model, tokenizer, prompt_tokens, num_new_tokens=10, verbose=False
+):
+    """
+    Directly measure KV cache memory by comparing memory before and after cache creation.
+
+    Args:
+        model: The loaded MLX model (as returned by mlx_lm.utils.load)
+        tokenizer: Tokenizer for the model
+        prompt_tokens: Tokenized prompt
+        num_new_tokens: Number of new tokens to generate
+        verbose: Whether to print verbose output
+
+    Returns:
+        Tuple of (kv_cache_size_gb, kv_per_token_gb)
+    """
+    # Force clean memory state
+    force_gc_and_reset()
+
+    # Get baseline memory
+    baseline = get_active_memory_gb()
+    if verbose:
+        print(f"Baseline memory before KV cache: {baseline:.4f} GB")
+
+    # Create inputs
+    inputs = mx.array([prompt_tokens])
+
+    # First measure memory with just the prompt - no generation
+    # Just do a forward pass to build the KV cache
+    logits = model(inputs)
+    mx.eval(logits)
+
+    # Measure memory with just prompt in KV cache
+    prompt_kv_memory = get_active_memory_gb() - baseline
+    if verbose:
+        print(
+            f"Memory after prompt KV cache: {prompt_kv_memory:.4f} GB for {len(prompt_tokens)} tokens"
+        )
+
+    # Reset for generation test
+    force_gc_and_reset()
+    baseline_with_prompt_kv = get_active_memory_gb()
+
+    # Use a simple manual greedy loop (rather than the full generate API) so the
+    # measurement isolates KV cache growth from sampling and detokenization overhead
+    input_ids = mx.array([prompt_tokens])
+
+    # Generate tokens one by one to measure KV cache growth
+    for _ in range(num_new_tokens):
+        # Forward pass
+        logits = model(input_ids)
+        next_token_logits = logits[0, -1, :]
+        next_token = mx.argmax(next_token_logits).item()
+        input_ids = mx.concatenate([input_ids, mx.array([[next_token]])], axis=1)
+        mx.eval(input_ids)
+
+    # Measure memory after generation including KV cache
+    total_kv_memory = (
+        get_active_memory_gb() - baseline_with_prompt_kv + prompt_kv_memory
+    )
+
+    # Calculate per-token KV cache size
+    total_tokens = len(prompt_tokens) + num_new_tokens
+    if num_new_tokens > 0:
+        per_token_gb = (total_kv_memory - prompt_kv_memory) / num_new_tokens
+    else:
+        per_token_gb = 0
+
+    if verbose:
+        print(
+            f"Total KV cache memory: {total_kv_memory:.4f} GB for {total_tokens} tokens"
+        )
+        print(f"Measured KV cache per token: {per_token_gb:.8f} GB")
+
+    return total_kv_memory, per_token_gb
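+
+
+# Typical invocation sketch (the model id is hypothetical; --kv-bits and --max-kv-size
+# mirror the estimator's own options):
+#
+#     python llms/test_memory_estimation.py mlx-community/SomeModel-4bit \
+#         --num-tokens 100 --verbose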
+
+
+def main():
+    parser = setup_arg_parser()
+    args = parser.parse_args()
+    verbose = args.verbose
+
+    # Measure baseline memory before loading model
+    force_gc_and_reset()
+    baseline_memory_gb = get_memory_usage()
+    print(f"Baseline memory usage: {baseline_memory_gb:.4f} GB")
+
+    print(f"Loading model from {args.model}...")
+
+    # Clear memory before starting
+    force_gc_and_reset()
+
+    # Load the model
+    model, tokenizer = load(args.model)
+
+    # Memory after loading model but before materializing parameters
+    model_load_memory_gb = get_memory_usage()
+
+    # Materialize model parameters to get accurate parameter memory usage
+    print("Materializing model parameters...")
+    force_gc_and_reset()
+
+    # Force materialization of all model parameters
+    for p in model.parameters().values():
+        mx.eval(p)
+
+    # Get memory used by materialized parameters
+    parameter_memory_gb = get_active_memory_gb()
+    print(f"Model parameters memory: {parameter_memory_gb:.4f} GB")
+
+    # Try to read prompt from file if it exists, otherwise use default
+    prompt = args.prompt
+    try:
+        if os.path.exists(args.prompt_file):
+            with open(args.prompt_file, "r") as f:
+                file_prompt = f.read().strip()
+            if file_prompt:
+                prompt = file_prompt
+                print(
+                    f"Using prompt from {args.prompt_file} ({len(prompt)} characters)"
+                )
+            else:
+                print("Empty prompt file, using default prompt")
+    except Exception as e:
+        print(f"Error reading prompt file: {e}, using default prompt")
+
+    # Count tokens in the prompt
+    tokens = tokenizer.encode(prompt)
+    prompt_token_count = len(tokens)
+    print(f"Prompt length: {prompt_token_count} tokens")
+
+    # Run inference with memory tracking
+    print("Running inference with prompt...")
+
+    # Reset memory state
+    force_gc_and_reset()
+
+    # First do a simple forward pass to measure activation memory
+    print("Running single forward pass...")
+
+    # Tokenize input
+    inputs = mx.array([tokens])
+
+    # Create an evaluation function that we will trace
+    def forward_pass():
+        return model(inputs)
+
+    # Trace the function to capture the actual memory during execution
+    output = forward_pass()
+    mx.eval(output)
+
+    # Get memory used during forward pass
+    forward_memory_gb = get_memory_usage()
+    print(f"Peak memory during forward pass: {forward_memory_gb:.4f} GB")
+
+    # Get model config for additional details
+    config = fetch_config(args.model)
+    bits_per_weight = compute_bits_per_weight_from_config(config)
+
+    # Get the necessary parameters for KV cache calculation from the model config
+    from mlx_lm.utils import _get_classes
+
+    _, model_args_class = _get_classes(config)
+    model_args = model_args_class.from_dict(config)
+
+    # Extract the parameters needed for KV cache calculation
+    head_dim = calculate_head_dim(config, model_args)
+    num_kv_heads = getattr(
+        model_args, "num_key_value_heads", model_args.num_attention_heads
+    )
+    num_layers = model_args.num_hidden_layers
+
+    # Now directly measure the KV cache memory
+    print("\nDirectly measuring KV cache memory usage...")
+    actual_kv_cache_gb, actual_per_token_gb = measure_kv_cache_memory(
+        model,
+        tokenizer,
+        tokens,
+        num_new_tokens=min(20, args.num_tokens),
+        verbose=verbose,
+    )
+
+    # Measure memory during full token generation using the stream_generate API
+    print("\nMeasuring memory during full token generation (streaming)...")
+    force_gc_and_reset()
+    baseline_for_generation = get_active_memory_gb()
+
+    # Use stream_generate to get token-by-token memory measurements
+    generation_text = ""
+    peak_memory_gb = 0
+    total_tokens_generated = 0
+    token_memory_profile = []
+
+    # Stream generation and track memory for each token
+    for response in stream_generate(
+        model, tokenizer, prompt, max_tokens=args.num_tokens
+    ):
+        generation_text += response.text
+        total_tokens_generated += 1
+
+        # Track memory per token
+        current_memory = response.peak_memory
+        peak_memory_gb = max(peak_memory_gb, current_memory)
+
+        # Record memory for this token
+        token_memory_profile.append(current_memory)
+
+        if verbose and total_tokens_generated % 10 == 0:
+            print(
+                f"Generated {total_tokens_generated} tokens, current memory: {current_memory:.4f} GB"
+            )
+
+    # Calculate final memory usage
+    generation_memory_gb = peak_memory_gb
+    print(
+        f"Peak memory during generation of {total_tokens_generated} tokens: {generation_memory_gb:.4f} GB"
+    )
+
+    # Optionally report a more detailed memory growth profile
+    if verbose:
+        print("\nMemory growth during generation:")
+        for i, mem in enumerate(token_memory_profile):
+            if i % 5 == 0 or i == len(token_memory_profile) - 1:
+                print(f"  Token {i+1}: {mem:.4f} GB")
+
+    # Calculate activation memory (peak memory during generation minus parameter memory and KV cache)
+    actual_activation_memory_gb = (
+        generation_memory_gb - parameter_memory_gb - actual_kv_cache_gb
+    )
+    if actual_activation_memory_gb < 0:
+        # This can happen due to memory reclamation between measurements
+        actual_activation_memory_gb = (
+            0.01 * parameter_memory_gb
+        )  # Use a reasonable fallback
+
+    # Get estimated memory usage from estimate_memory.py
+    estimated_results, mode = estimate_mem(
+        args.model,
+        context_length=prompt_token_count,
+        max_kv_size=args.max_kv_size,
+        kv_bits=args.kv_bits,
+        kv_group_size=args.kv_group_size,
+        tokens_to_generate=total_tokens_generated,
+    )
+
+    # Compare estimation accuracy
+    print("\n--- MEMORY ESTIMATION ACCURACY ---")
+    print(f"Model: {args.model}")
+    print(f"Architecture: {config.get('model_type', 'unknown')}")
+    print(f"Quantization: {bits_per_weight} bits per weight")
+    print(f"Total tokens processed: {prompt_token_count + total_tokens_generated}")
+    print("-" * 40)
+    print(f"Actual model parameters memory: {parameter_memory_gb:.3g} GB")
+    print(f"Estimated model parameters memory: {estimated_results['Weight']:.3g} GB")
+    print(
+        f"Model memory error: {abs(parameter_memory_gb - estimated_results['Weight']):.3g} GB"
+    )
+    print("-" * 40)
+    print(f"Actual KV cache memory: {actual_kv_cache_gb:.3g} GB")
+    print(f"Estimated KV cache memory: {estimated_results['KV Cache']:.3g} GB")
+    print(
+        f"KV cache memory error: {abs(actual_kv_cache_gb - estimated_results['KV Cache']):.3g} GB"
+    )
+    print("-" * 40)
+    print(f"Actual per-token KV increase: {actual_per_token_gb:.6g} GB")
+    print(
+        f"Estimated per-token KV increase: {estimated_results['per_token_increase']:.6g} GB"
+    )
+    print(
+        f"Per-token KV error: {abs(actual_per_token_gb - estimated_results['per_token_increase']):.6g} GB"
+    )
+    print("-" * 40)
+    print(f"Actual activation memory: {actual_activation_memory_gb:.3g} GB")
+    print(f"Estimated overhead memory: {estimated_results['Overhead']:.3g} GB")
+    print(
+        f"Overhead memory error: {abs(actual_activation_memory_gb - estimated_results['Overhead']):.3g} GB"
+    )
+    print("-" * 40)
+    print(f"Total peak memory (actual): {generation_memory_gb:.3g} GB")
+    print(f"Total memory (estimated): {estimated_results['Total']:.3g} GB")
+    print(
+        f"Total memory error: {abs(generation_memory_gb - estimated_results['Total']):.3g} GB"
+    )
+    print(f"Mode: {mode}")
+
+
+if __name__ == "__main__":
+    main()