Add memory estimation tool for MLX language models

This commit introduces a comprehensive memory estimation utility for MLX language models, supporting:
- Dynamic parameter calculation across diverse model architectures
- Handling of quantized and standard models
- Estimation of model weights, KV cache, and overhead memory
- Support for bounded and unbounded KV cache modes
- Flexible configuration via command-line arguments

The new tool provides detailed memory usage insights for different model configurations and generation scenarios.
This commit is contained in:
Cavit Erginsoy 2025-03-10 02:59:09 +00:00
parent 877d2a345b
commit 7ee76a32a4
3 changed files with 845 additions and 0 deletions

View File

@ -7,3 +7,9 @@ from ._version import __version__
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
from .utils import convert, generate, load, stream_generate
def get_estimate_mem():
from .estimate_memory import estimate_mem
return estimate_mem

View File

@ -0,0 +1,470 @@
# Copyright © 2023-2025 Apple Inc.
import argparse
import json
import math
import os
from pathlib import Path
from typing import Dict, Optional, Tuple, Type
from huggingface_hub import hf_hub_download, try_to_load_from_cache
from mlx_lm.models.base import BaseModelArgs
from mlx_lm.utils import _get_classes
def fetch_metadata(model_path: str) -> Tuple[Dict, Optional[int]]:
"""Fetch config.json and optionally model.safetensors.index.json for weights size."""
config = fetch_config(model_path)
model_weight_size = None
if not os.path.isdir(model_path):
try:
index_path = hf_hub_download(
repo_id=model_path, filename="model.safetensors.index.json"
)
with open(index_path, "r") as f:
index = json.load(f)
model_weight_size = index.get("metadata", {}).get("total_size")
except:
pass
return config, model_weight_size
def fetch_config(model_path: str) -> Dict:
"""Fetch or load config.json without downloading the full model, checking cache first."""
if os.path.isdir(model_path):
config_path = Path(model_path) / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"config.json not found in {model_path}")
with open(config_path, "r") as f:
return json.load(f)
else:
cached_path = try_to_load_from_cache(
repo_id=model_path, filename="config.json", repo_type="model"
)
if cached_path and os.path.exists(cached_path):
with open(cached_path, "r") as f:
return json.load(f)
try:
config_path = hf_hub_download(repo_id=model_path, filename="config.json")
with open(config_path, "r") as f:
return json.load(f)
except Exception as e:
raise ValueError(f"Failed to fetch config.json from {model_path}: {str(e)}")
def compute_bits_per_weight_from_config(config: Dict) -> float:
"""Infer bits-per-weight from config, defaulting to 16 (FP16) if unquantized."""
quantization = config.get("quantization", {})
bits = quantization.get("bits")
return float(bits) if bits is not None else 16.0
def calc_embedding_params(vocab_size: int, hidden_size: int) -> int:
"""Calculate parameters for the embedding layer."""
return vocab_size * hidden_size
def calc_attention_params(args, hidden_size: int, num_attention_heads: int) -> int:
"""Calculate parameters for one attention layer, handling standard and LoRA variants.
This function supports both standard multi-head attention (e.g., Mixtral, OLMoE) and
LoRA-like attention (e.g., DeepSeek V3) by checking for q_lora_rank, allowing flexibility
over the older hardcoded approach that assumed uniform QKV dimensions.
"""
num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
head_dim = getattr(args, "head_dim", None)
if head_dim is None:
head_dim = hidden_size // num_attention_heads
has_bias = getattr(args, "attention_bias", False)
# Standard attention (Q, K, V, O)
if not hasattr(args, "q_lora_rank") or not args.q_lora_rank:
q_params = hidden_size * (num_attention_heads * head_dim)
k_params = hidden_size * (num_kv_heads * head_dim)
v_params = k_params
o_params = (num_attention_heads * head_dim) * hidden_size
# LoRA-like attention (e.g., DeepSeek V3)
else:
q_head_dim = getattr(args, "qk_nope_head_dim", 0) + getattr(
args, "qk_rope_head_dim", head_dim
)
v_head_dim = getattr(args, "v_head_dim", head_dim)
qk_rope_dim = getattr(args, "qk_rope_head_dim", head_dim)
q_params = hidden_size * args.q_lora_rank + args.q_lora_rank * (
num_attention_heads * q_head_dim
)
k_params = hidden_size * (args.kv_lora_rank + qk_rope_dim)
v_params = args.kv_lora_rank * (
num_attention_heads * (q_head_dim - qk_rope_dim + v_head_dim)
)
o_params = (num_attention_heads * v_head_dim) * hidden_size
total = q_params + k_params + v_params + o_params
if has_bias:
total += (
num_attention_heads * head_dim + num_kv_heads * head_dim * 2 + hidden_size
)
return total
def calc_ffn_or_moe_params(
args, hidden_size: int, intermediate_size: int, layer_idx: int
) -> int:
"""Calculate parameters for FFN or MoE layer, switching based on config.
Unlike the previous hardcoded FFN-only calculation, this function dynamically handles
Mixture of Experts (MoE) models like Mixtral, OLMoE, and DeepSeek V3 by detecting
expert-related fields (num_experts, n_routed_experts) and adjusting for dense vs.
MoE layers, supporting varied intermediate sizes and shared experts.
"""
num_experts = max(
getattr(args, "num_local_experts", 0),
getattr(args, "num_experts", 0),
getattr(args, "n_routed_experts", 0),
)
moe_intermediate_size = getattr(args, "moe_intermediate_size", intermediate_size)
dense_up_to = (
getattr(args, "first_k_dense_replace", 0)
if num_experts
else args.num_hidden_layers
)
has_bias = getattr(args, "mlp_bias", False)
shared_experts = getattr(args, "n_shared_experts", 0)
if num_experts and layer_idx >= dense_up_to:
# MoE: gate + expert FFNs
gate_params = hidden_size * num_experts
expert_params = (
num_experts * hidden_size * moe_intermediate_size * 3
) # gate_proj, up_proj, down_proj
shared_params = (
shared_experts * hidden_size * moe_intermediate_size * 3
if shared_experts
else 0
)
return gate_params + expert_params + shared_params
else:
# Dense FFN
ffn_params = (
hidden_size * intermediate_size * 2 + intermediate_size * hidden_size
)
if has_bias:
ffn_params += intermediate_size * 2 + hidden_size
return ffn_params
def calc_norm_params(args, hidden_size: int, num_attention_heads: int) -> int:
"""Calculate normalization parameters, adjusting for extra norms in complex models.
This extends the old approach (fixed 2 norms per layer) by adding heuristic support
for extra normalization layers (e.g., OLMoE's q_norm, k_norm) in MoE or LoRA models,
improving accuracy over the simpler assumption of uniform RMSNorm usage.
"""
num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
head_dim = getattr(args, "head_dim", None)
if head_dim is None:
head_dim = hidden_size // num_attention_heads
# Base: input + post-attention RMSNorm
total = hidden_size * 2
# Heuristic: extra norms for MoE or complex attention
has_experts = any(
getattr(args, attr, 0) > 0
for attr in ["num_local_experts", "num_experts", "n_routed_experts"]
)
if has_experts or hasattr(args, "q_lora_rank"):
total += (num_attention_heads * head_dim) + (num_kv_heads * head_dim)
return total
def calculate_num_parameters(
config: Dict, model_args_class: Optional[Type["BaseModelArgs"]] = None
) -> int:
"""Calculate the total number of parameters in a model based on its config.
By splitting into separate functions, we now support diverse
architectures while maintaining readability and avoiding model-specific hardcoding.
"""
# Use the imported _get_classes function to get the ModelArgs class
if model_args_class is None:
_, model_args_class = _get_classes(config)
args = model_args_class.from_dict(config)
# Validate required fields
required = [
"hidden_size",
"num_hidden_layers",
"vocab_size",
"num_attention_heads",
"intermediate_size",
]
missing = [field for field in required if getattr(args, field, None) is None]
if missing:
raise ValueError(f"Config missing required fields: {missing}")
# Core config
hidden_size = args.hidden_size
num_layers = args.num_hidden_layers
vocab_size = args.vocab_size
num_attention_heads = args.num_attention_heads
intermediate_size = args.intermediate_size
# Total calculation
total_params = calc_embedding_params(vocab_size, hidden_size)
for layer in range(num_layers):
total_params += calc_attention_params(args, hidden_size, num_attention_heads)
total_params += calc_ffn_or_moe_params(
args, hidden_size, intermediate_size, layer
)
total_params += calc_norm_params(args, hidden_size, num_attention_heads)
total_params += hidden_size # Final norm
if not getattr(args, "tie_word_embeddings", True):
total_params += hidden_size * vocab_size # LM head
return total_params
def calculate_head_dim(config: Dict, args: BaseModelArgs) -> int:
"""Infer head dimension dynamically from config or args."""
head_dim = getattr(args, "head_dim", None)
if head_dim is None:
if "hidden_size" not in config or "num_attention_heads" not in config:
raise ValueError(
"Cannot compute head_dim: missing hidden_size or num_attention_heads"
)
head_dim = config["hidden_size"] // config["num_attention_heads"]
return head_dim
def estimate_mem(
model_path: str,
context_length: int = 4096,
max_kv_size: Optional[int] = None,
kv_bits: Optional[int] = None,
kv_group_size: Optional[int] = None,
tokens_to_generate: int = 0,
) -> Tuple[Dict[str, float], str]:
"""
Estimate the memory usage of a model.
Args:
model_path: Path to the model.
context_length: Context length of the model (prompt length in unbounded mode).
max_kv_size: Maximum size of the KV cache (for bounded mode).
kv_bits: Number of bits to use for quantized KV cache.
kv_group_size: Group size to use for quantized KV cache.
tokens_to_generate: Number of tokens to generate beyond the prompt.
Returns:
A tuple of (results, mode) where results is a dictionary of memory usage
and mode is a string indicating the mode of the KV cache.
"""
config, model_weight_size = fetch_metadata(model_path)
bits_per_weight = compute_bits_per_weight_from_config(config)
# Determine the model class
_, model_args_class = _get_classes(config)
args = model_args_class.from_dict(config)
# Calculate the number of parameters
num_parameters = calculate_num_parameters(config, model_args_class)
# Extract model architecture parameters needed for memory calculations
num_layers = args.num_hidden_layers
num_kv_heads = getattr(args, "num_key_value_heads", args.num_attention_heads)
head_dim = calculate_head_dim(config, args)
# Default to fp16 (2 bytes per element) for KV cache unless quantized
bytes_per_element = 2
# If kv_bits and kv_group_size are not provided, try to read from config
if kv_bits is None or kv_group_size is None:
# Try to get quantization settings from config
quantization = config.get("quantization", {})
quantization_config = config.get("quantization_config", {})
# Use the first available quantization config
quant_info = quantization or quantization_config
if quant_info:
kv_bits = kv_bits or quant_info.get("bits")
kv_group_size = kv_group_size or quant_info.get("group_size")
# Calculate the model weight memory usage
bytes_per_parameter = bits_per_weight / 8
if model_weight_size:
# Use the size from safetensors index if available
model_size_gb = model_weight_size / (1024**3)
else:
# Calculate from parameter count
model_size_gb = (num_parameters * bytes_per_parameter) / (1024**3)
# Estimate tokenizer size
vocab_size = config.get("vocab_size", args.vocab_size)
fixed_overhead_bytes = 25 * 1024 * 1024
avg_token_size_bytes = 650
tokenizer_size_bytes = (vocab_size * avg_token_size_bytes) + fixed_overhead_bytes
tokenizer_size_gb = tokenizer_size_bytes / (1024**3)
# Determine the mode
mode = "Bounded" if max_kv_size else "Unbounded"
# KV length is fixed to max_kv_size in bounded mode, or context_length in unbounded mode
kv_length = max_kv_size if mode == "Bounded" else context_length
# Default step size from cache.py is 256
step_size = 256
kv_length_padded = ((kv_length + step_size - 1) // step_size) * step_size
# Calculate KV cache size based on whether quantization is used
if kv_bits and kv_group_size:
# Quantized cache calculations
groups_per_head_dim = (head_dim + kv_group_size - 1) // kv_group_size
elements_per_int = 8 * 4 // kv_bits
data_size = (
num_kv_heads * kv_length_padded * (head_dim // elements_per_int) * 4
) / (1024**3)
quant_overhead = (
num_kv_heads * kv_length_padded * groups_per_head_dim * 2 * 2
) / (1024**3)
per_layer_kv_size = 2 * (data_size + quant_overhead)
elements_per_token = (head_dim // elements_per_int) * 4
scales_zeros_per_token = groups_per_head_dim * 2 * 2
per_token_bytes = (
2 * num_kv_heads * (elements_per_token + scales_zeros_per_token)
)
per_token_increase = (per_token_bytes * num_layers) / (1024**3)
else:
# Standard fp16 cache
per_layer_kv_size = (
2 * num_kv_heads * kv_length_padded * head_dim * bytes_per_element
) / (1024**3)
per_token_increase = (
2 * num_kv_heads * head_dim * bytes_per_element * num_layers
) / (1024**3)
total_kv_cache_size = num_layers * per_layer_kv_size
# Add the memory for generated tokens if specified
if tokens_to_generate > 0:
total_kv_cache_size += tokens_to_generate * per_token_increase
# For inference in MLX, estimate activation memory
activation_size_gb = 0.03 * model_size_gb
overhead_gb = tokenizer_size_gb + activation_size_gb + (model_size_gb * 0.05)
# Total memory usage
total_memory_gb = model_size_gb + total_kv_cache_size + overhead_gb
results = {
"Weight": model_size_gb,
"KV Cache": total_kv_cache_size,
"Overhead": overhead_gb,
"Total": total_memory_gb,
"per_token_increase": per_token_increase,
}
return results, mode
def setup_arg_parser():
parser = argparse.ArgumentParser(description="MLX Model URAM Estimation Tool")
parser.add_argument("model", help="Local model path or Hugging Face repo ID.")
parser.add_argument(
"--context-length",
type=int,
default=4096,
help="Context length of the model (prompt length in unbounded mode).",
)
parser.add_argument(
"--max-kv-size", type=int, help="Max KV cache size (enables bounded mode)."
)
parser.add_argument("--kv-bits", type=int, help="Bits for KV cache quantization.")
parser.add_argument(
"--kv-group-size", type=int, help="Group size for KV cache quantization."
)
parser.add_argument(
"--tokens-to-generate",
type=int,
default=0,
help="Number of tokens to generate beyond the prompt.",
)
return parser
def print_table(
results: Dict[str, float], mode: str, tokens_to_generate: int = 0
) -> None:
"""
Print a memory usage table in a formatted way.
Args:
results: Dictionary containing memory usage data (in GB unless specified).
mode: Either "Bounded" or "Unbounded" to describe the KV Cache type.
tokens_to_generate: Number of tokens generated (optional, defaults to 0).
"""
# Construct the title dynamically
title = f"*Memory Usage Estimate ({mode} KV Cache"
if tokens_to_generate > 0:
title += f" after generating {tokens_to_generate:,} tokens"
title += "):*"
# Define table formatting constants
LINE_WIDTH = 34
ITEM_COL_WIDTH = 17
MEMORY_COL_WIDTH = 12
# Print header
print(title)
print("-" * LINE_WIDTH)
print(f"| {'Item':<{ITEM_COL_WIDTH}} | {'Memory':<{MEMORY_COL_WIDTH}} |")
print("-" * LINE_WIDTH)
# Define display order and handle missing keys gracefully
display_order = ["Weight", "KV Cache", "Overhead", "Total"]
for key in display_order:
value = results.get(key, 0.0) # Default to 0.0 if key is missing
memory_str = f"{value:.2f} GB"
# Print row (extra spaces for alignment)
if key == "Total":
print("-" * LINE_WIDTH)
print(f"| {key:<{ITEM_COL_WIDTH}} | {memory_str:>{MEMORY_COL_WIDTH}} |")
print("-" * LINE_WIDTH)
# Add footer for Unbounded mode
if mode == "Unbounded" and "per_token_increase" in results:
per_token_gb = results["per_token_increase"]
if per_token_gb > 0: # Avoid division by zero
tokens_per_gb = math.floor(1 / per_token_gb)
print(f"Additional tokens per 1GB increase: {tokens_per_gb:,}")
else:
print("Note: Per-token increase is zero or invalid.")
def main():
parser = setup_arg_parser()
args = parser.parse_args()
results, mode = estimate_mem(
args.model,
args.context_length,
args.max_kv_size,
args.kv_bits,
args.kv_group_size,
args.tokens_to_generate,
)
print_table(results, mode, args.tokens_to_generate)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,369 @@
#!/usr/bin/env python3
# Copyright © 2023-2025 Apple Inc.
"""
Test script to validate memory usage estimation from estimate.py.
Loads a model, runs inference, and compares actual vs estimated memory usage.
"""
import argparse
import gc
import os
import time
from typing import Dict, Optional, Tuple
import mlx.core as mx
from mlx_lm.estimate import (
calculate_head_dim,
compute_bits_per_weight_from_config,
estimate_uram,
fetch_config,
)
from mlx_lm.utils import generate, load, stream_generate
def setup_arg_parser():
parser = argparse.ArgumentParser(description="Test Memory Estimation Accuracy")
parser.add_argument("model", help="Local model path or Hugging Face repo ID.")
parser.add_argument(
"--prompt", type=str, default="Once upon a time", help="Prompt for inference."
)
parser.add_argument(
"--prompt-file",
type=str,
default="prompt.txt",
help="File containing the prompt.",
)
parser.add_argument(
"--num-tokens", type=int, default=50, help="Number of tokens to generate."
)
parser.add_argument("--kv-bits", type=int, help="Bits for KV cache quantization.")
parser.add_argument(
"--kv-group-size",
type=int,
default=64,
help="Group size for KV cache quantization.",
)
parser.add_argument(
"--max-kv-size", type=int, help="Max KV cache size (bounded mode)."
)
parser.add_argument(
"--context-length",
type=int,
default=4096,
help="Context length for estimation.",
)
parser.add_argument(
"--verbose", action="store_true", help="Enable verbose logging."
)
return parser
def get_memory_usage():
"""Get current memory usage in GB."""
used_gb = mx.metal.get_peak_memory() / (1024**3)
mx.metal.reset_peak_memory()
return used_gb
def get_active_memory_gb():
"""Get current active memory usage in GB."""
return mx.metal.get_active_memory() / (1024**3)
def force_gc_and_reset():
"""Force garbage collection and reset memory counters."""
gc.collect()
mx.metal.reset_peak_memory()
time.sleep(0.1) # Small delay to ensure memory operations complete
def measure_kv_cache_memory(
model_pkg, tokenizer, prompt_tokens, num_new_tokens=10, verbose=False
):
"""
Directly measure KV cache memory by comparing memory before and after cache creation.
Args:
model_pkg: The loaded model package (contains model and generate function)
tokenizer: Tokenizer for the model
prompt_tokens: Tokenized prompt
num_new_tokens: Number of new tokens to generate
verbose: Whether to print verbose output
Returns:
Tuple of (kv_cache_size_gb, kv_per_token_gb)
"""
# Force clean memory state
force_gc_and_reset()
# Get baseline memory
baseline = get_active_memory_gb()
if verbose:
print(f"Baseline memory before KV cache: {baseline:.4f} GB")
# Create inputs
inputs = mx.array([prompt_tokens])
# First measure memory with just the prompt - no generation
# Just do a forward pass to build the KV cache
logits = model_pkg.model(inputs)
mx.eval(logits)
# Measure memory with just prompt in KV cache
prompt_kv_memory = get_active_memory_gb() - baseline
if verbose:
print(
f"Memory after prompt KV cache: {prompt_kv_memory:.4f} GB for {len(prompt_tokens)} tokens"
)
# Reset for generation test
force_gc_and_reset()
baseline_with_prompt_kv = get_active_memory_gb()
# For generation, we need to use the mlx_lm generate function, not model.generate
# Create a simple manual generation loop to measure memory impact
input_ids = mx.array([prompt_tokens])
# Generate tokens one by one to measure KV cache growth
for _ in range(num_new_tokens):
# Forward pass
logits = model_pkg.model(input_ids)
next_token_logits = logits[0, -1, :]
next_token = mx.argmax(next_token_logits)
input_ids = mx.concatenate([input_ids, mx.array([[next_token]])], axis=1)
mx.eval(input_ids)
# Measure memory after generation including KV cache
total_kv_memory = (
get_active_memory_gb() - baseline_with_prompt_kv + prompt_kv_memory
)
# Calculate per-token KV cache size
total_tokens = len(prompt_tokens) + num_new_tokens
if num_new_tokens > 0:
per_token_gb = (total_kv_memory - prompt_kv_memory) / num_new_tokens
else:
per_token_gb = 0
if verbose:
print(
f"Total KV cache memory: {total_kv_memory:.4f} GB for {total_tokens} tokens"
)
print(f"Measured KV cache per token: {per_token_gb:.8f} GB")
return total_kv_memory, per_token_gb
def main():
parser = setup_arg_parser()
args = parser.parse_args()
verbose = args.verbose
# Measure baseline memory before loading model
force_gc_and_reset()
baseline_memory_gb = get_memory_usage()
print(f"Baseline memory usage: {baseline_memory_gb:.4f} GB")
print(f"Loading model from {args.model}...")
# Clear memory before starting
force_gc_and_reset()
# Load the model
model_pkg, tokenizer = load(args.model)
# Memory after loading model but before materializing parameters
model_load_memory_gb = get_memory_usage()
# Materialize model parameters to get accurate parameter memory usage
print("Materializing model parameters...")
force_gc_and_reset()
# Force materialization of all model parameters
for p in model_pkg.model.parameters().values():
mx.eval(p)
# Get memory used by materialized parameters
parameter_memory_gb = get_active_memory_gb()
print(f"Model parameters memory: {parameter_memory_gb:.4f} GB")
# Try to read prompt from file if it exists, otherwise use default
prompt = args.prompt
try:
if os.path.exists(args.prompt_file):
with open(args.prompt_file, "r") as f:
file_prompt = f.read().strip()
if file_prompt:
prompt = file_prompt
print(
f"Using prompt from {args.prompt_file} ({len(prompt)} characters)"
)
else:
print(f"Empty prompt file, using default prompt")
except Exception as e:
print(f"Error reading prompt file: {e}, using default prompt")
# Count tokens in the prompt
tokens = tokenizer.encode(prompt)
prompt_token_count = len(tokens)
print(f"Prompt length: {prompt_token_count} tokens")
# Run inference with memory tracking
print(f"Running inference with prompt...")
# Reset memory state
force_gc_and_reset()
# First do a simple forward pass to measure activation memory
print("Running single forward pass...")
# Tokenize input
inputs = mx.array([tokens])
# Create an evaluation function that we will trace
def forward_pass():
return model_pkg.model(inputs)
# Trace the function to capture the actual memory during execution
output = forward_pass()
mx.eval(output)
# Get memory used during forward pass
forward_memory_gb = get_memory_usage()
print(f"Peak memory during forward pass: {forward_memory_gb:.4f} GB")
# Get model config for additional details
config = fetch_config(args.model)
bits_per_weight = compute_bits_per_weight_from_config(config)
# Get the necessary parameters for KV cache calculation from the model config
from mlx_lm.utils import _get_classes
_, model_args_class = _get_classes(config)
model_args = model_args_class.from_dict(config)
# Extract the parameters needed for KV cache calculation
head_dim = calculate_head_dim(config, model_args)
num_kv_heads = getattr(
model_args, "num_key_value_heads", model_args.num_attention_heads
)
num_layers = model_args.num_hidden_layers
# Now directly measure the KV cache memory
print("\nDirectly measuring KV cache memory usage...")
actual_kv_cache_gb, actual_per_token_gb = measure_kv_cache_memory(
model_pkg,
tokenizer,
tokens,
num_new_tokens=min(20, args.num_tokens),
verbose=verbose,
)
# Measure memory during token generation (inference) using proper generate function
print("\nMeasuring memory during full token generation (streaming)...")
force_gc_and_reset()
baseline_for_generation = get_active_memory_gb()
# Use stream_generate to get token-by-token memory measurements
generation_text = ""
peak_memory_gb = 0
total_tokens_generated = 0
token_memory_profile = []
# Stream generation and track memory for each token
for response in stream_generate(
model_pkg.model, tokenizer, prompt, max_tokens=args.num_tokens
):
generation_text += response.text
total_tokens_generated += 1
# Track memory per token
current_memory = response.peak_memory
peak_memory_gb = max(peak_memory_gb, current_memory)
# Record memory for this token
token_memory_profile.append(current_memory)
if verbose and total_tokens_generated % 10 == 0:
print(
f"Generated {total_tokens_generated} tokens, current memory: {current_memory:.4f} GB"
)
# Calculate final memory usage
generation_memory_gb = peak_memory_gb
print(
f"Peak memory during generation of {total_tokens_generated} tokens: {generation_memory_gb:.4f} GB"
)
# You can also add this to get more detailed memory profile analysis
if verbose:
print("\nMemory growth during generation:")
for i, mem in enumerate(token_memory_profile):
if i % 5 == 0 or i == len(token_memory_profile) - 1:
print(f" Token {i+1}: {mem:.4f} GB")
# Calculate activation memory (peak memory during generation minus parameter memory and KV cache)
actual_activation_memory_gb = (
generation_memory_gb - parameter_memory_gb - actual_kv_cache_gb
)
if actual_activation_memory_gb < 0:
# This can happen due to memory reclamation between measurements
actual_activation_memory_gb = (
0.01 * parameter_memory_gb
) # Use a reasonable fallback
# Get estimated memory usage from estimate.py
estimated_results, mode = estimate_uram(
args.model,
context_length=args.context_length,
max_kv_size=args.max_kv_size,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
initial_prompt_length=prompt_token_count,
extra_tokens=total_tokens_generated,
)
# Compare estimation accuracy
print("\n--- MEMORY ESTIMATION ACCURACY ---")
print(f"Model: {args.model}")
print(f"Architecture: {config.get('model_type', 'unknown')}")
print(f"Quantization: {bits_per_weight} bits per weight")
print(f"Total tokens processed: {prompt_token_count + total_tokens_generated}")
print("-" * 40)
print(f"Actual model parameters memory: {parameter_memory_gb:.3g} GB")
print(f"Estimated model parameters memory: {estimated_results['Model']:.3g} GB")
print(
f"Model memory error: {abs(parameter_memory_gb - estimated_results['Model']):.3g} GB"
)
print("-" * 40)
print(f"Actual KV cache memory: {actual_kv_cache_gb:.3g} GB")
print(f"Estimated KV cache memory: {estimated_results['KV Cache']:.3g} GB")
print(
f"KV cache memory error: {abs(actual_kv_cache_gb - estimated_results['KV Cache']):.3g} GB"
)
print("-" * 40)
print(f"Actual per-token KV increase: {actual_per_token_gb:.6g} GB")
print(
f"Estimated per-token KV increase: {estimated_results['per_token_increase']:.6g} GB"
)
print(
f"Per-token KV error: {abs(actual_per_token_gb - estimated_results['per_token_increase']):.6g} GB"
)
print("-" * 40)
print(f"Actual activation memory: {actual_activation_memory_gb:.3g} GB")
print(f"Estimated activation memory: {estimated_results['Activations']:.3g} GB")
print(
f"Activation memory error: {abs(actual_activation_memory_gb - estimated_results['Activations']):.3g} GB"
)
print("-" * 40)
print(f"Total peak memory (actual): {generation_memory_gb:.3g} GB")
print(f"Total memory (estimated): {estimated_results['Total']:.3g} GB")
print(
f"Total memory error: {abs(generation_memory_gb - estimated_results['Total']):.3g} GB"
)
print(f"Mode: {mode}")
if __name__ == "__main__":
main()