mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add memory estimation tool for MLX language models
This commit introduces a comprehensive memory estimation utility for MLX language models, supporting: - Dynamic parameter calculation across diverse model architectures - Handling of quantized and standard models - Estimation of model weights, KV cache, and overhead memory - Support for bounded and unbounded KV cache modes - Flexible configuration via command-line arguments The new tool provides detailed memory usage insights for different model configurations and generation scenarios.
This commit is contained in:
parent
877d2a345b
commit
7ee76a32a4
@ -7,3 +7,9 @@ from ._version import __version__
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
||||
|
||||
from .utils import convert, generate, load, stream_generate
|
||||
|
||||
|
||||
def get_estimate_mem():
|
||||
from .estimate_memory import estimate_mem
|
||||
|
||||
return estimate_mem
|
||||
|
470
llms/mlx_lm/estimate_memory.py
Normal file
470
llms/mlx_lm/estimate_memory.py
Normal file
@ -0,0 +1,470 @@
|
||||
# Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple, Type
|
||||
|
||||
from huggingface_hub import hf_hub_download, try_to_load_from_cache
|
||||
from mlx_lm.models.base import BaseModelArgs
|
||||
from mlx_lm.utils import _get_classes
|
||||
|
||||
|
||||
def fetch_metadata(model_path: str) -> Tuple[Dict, Optional[int]]:
|
||||
"""Fetch config.json and optionally model.safetensors.index.json for weights size."""
|
||||
config = fetch_config(model_path)
|
||||
model_weight_size = None
|
||||
if not os.path.isdir(model_path):
|
||||
try:
|
||||
index_path = hf_hub_download(
|
||||
repo_id=model_path, filename="model.safetensors.index.json"
|
||||
)
|
||||
with open(index_path, "r") as f:
|
||||
index = json.load(f)
|
||||
model_weight_size = index.get("metadata", {}).get("total_size")
|
||||
except:
|
||||
pass
|
||||
return config, model_weight_size
|
||||
|
||||
|
||||
def fetch_config(model_path: str) -> Dict:
|
||||
"""Fetch or load config.json without downloading the full model, checking cache first."""
|
||||
if os.path.isdir(model_path):
|
||||
config_path = Path(model_path) / "config.json"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"config.json not found in {model_path}")
|
||||
with open(config_path, "r") as f:
|
||||
return json.load(f)
|
||||
else:
|
||||
cached_path = try_to_load_from_cache(
|
||||
repo_id=model_path, filename="config.json", repo_type="model"
|
||||
)
|
||||
if cached_path and os.path.exists(cached_path):
|
||||
with open(cached_path, "r") as f:
|
||||
return json.load(f)
|
||||
try:
|
||||
config_path = hf_hub_download(repo_id=model_path, filename="config.json")
|
||||
with open(config_path, "r") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch config.json from {model_path}: {str(e)}")
|
||||
|
||||
|
||||
def compute_bits_per_weight_from_config(config: Dict) -> float:
|
||||
"""Infer bits-per-weight from config, defaulting to 16 (FP16) if unquantized."""
|
||||
quantization = config.get("quantization", {})
|
||||
bits = quantization.get("bits")
|
||||
return float(bits) if bits is not None else 16.0
|
||||
|
||||
|
||||
def calc_embedding_params(vocab_size: int, hidden_size: int) -> int:
|
||||
"""Calculate parameters for the embedding layer."""
|
||||
return vocab_size * hidden_size
|
||||
|
||||
|
||||
def calc_attention_params(args, hidden_size: int, num_attention_heads: int) -> int:
|
||||
"""Calculate parameters for one attention layer, handling standard and LoRA variants.
|
||||
|
||||
This function supports both standard multi-head attention (e.g., Mixtral, OLMoE) and
|
||||
LoRA-like attention (e.g., DeepSeek V3) by checking for q_lora_rank, allowing flexibility
|
||||
over the older hardcoded approach that assumed uniform QKV dimensions.
|
||||
"""
|
||||
num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
|
||||
head_dim = getattr(args, "head_dim", None)
|
||||
if head_dim is None:
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
has_bias = getattr(args, "attention_bias", False)
|
||||
|
||||
# Standard attention (Q, K, V, O)
|
||||
if not hasattr(args, "q_lora_rank") or not args.q_lora_rank:
|
||||
q_params = hidden_size * (num_attention_heads * head_dim)
|
||||
k_params = hidden_size * (num_kv_heads * head_dim)
|
||||
v_params = k_params
|
||||
o_params = (num_attention_heads * head_dim) * hidden_size
|
||||
# LoRA-like attention (e.g., DeepSeek V3)
|
||||
else:
|
||||
q_head_dim = getattr(args, "qk_nope_head_dim", 0) + getattr(
|
||||
args, "qk_rope_head_dim", head_dim
|
||||
)
|
||||
v_head_dim = getattr(args, "v_head_dim", head_dim)
|
||||
qk_rope_dim = getattr(args, "qk_rope_head_dim", head_dim)
|
||||
q_params = hidden_size * args.q_lora_rank + args.q_lora_rank * (
|
||||
num_attention_heads * q_head_dim
|
||||
)
|
||||
k_params = hidden_size * (args.kv_lora_rank + qk_rope_dim)
|
||||
v_params = args.kv_lora_rank * (
|
||||
num_attention_heads * (q_head_dim - qk_rope_dim + v_head_dim)
|
||||
)
|
||||
o_params = (num_attention_heads * v_head_dim) * hidden_size
|
||||
|
||||
total = q_params + k_params + v_params + o_params
|
||||
if has_bias:
|
||||
total += (
|
||||
num_attention_heads * head_dim + num_kv_heads * head_dim * 2 + hidden_size
|
||||
)
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def calc_ffn_or_moe_params(
|
||||
args, hidden_size: int, intermediate_size: int, layer_idx: int
|
||||
) -> int:
|
||||
"""Calculate parameters for FFN or MoE layer, switching based on config.
|
||||
|
||||
Unlike the previous hardcoded FFN-only calculation, this function dynamically handles
|
||||
Mixture of Experts (MoE) models like Mixtral, OLMoE, and DeepSeek V3 by detecting
|
||||
expert-related fields (num_experts, n_routed_experts) and adjusting for dense vs.
|
||||
MoE layers, supporting varied intermediate sizes and shared experts.
|
||||
"""
|
||||
num_experts = max(
|
||||
getattr(args, "num_local_experts", 0),
|
||||
getattr(args, "num_experts", 0),
|
||||
getattr(args, "n_routed_experts", 0),
|
||||
)
|
||||
moe_intermediate_size = getattr(args, "moe_intermediate_size", intermediate_size)
|
||||
dense_up_to = (
|
||||
getattr(args, "first_k_dense_replace", 0)
|
||||
if num_experts
|
||||
else args.num_hidden_layers
|
||||
)
|
||||
has_bias = getattr(args, "mlp_bias", False)
|
||||
shared_experts = getattr(args, "n_shared_experts", 0)
|
||||
|
||||
if num_experts and layer_idx >= dense_up_to:
|
||||
# MoE: gate + expert FFNs
|
||||
gate_params = hidden_size * num_experts
|
||||
expert_params = (
|
||||
num_experts * hidden_size * moe_intermediate_size * 3
|
||||
) # gate_proj, up_proj, down_proj
|
||||
shared_params = (
|
||||
shared_experts * hidden_size * moe_intermediate_size * 3
|
||||
if shared_experts
|
||||
else 0
|
||||
)
|
||||
return gate_params + expert_params + shared_params
|
||||
else:
|
||||
# Dense FFN
|
||||
ffn_params = (
|
||||
hidden_size * intermediate_size * 2 + intermediate_size * hidden_size
|
||||
)
|
||||
if has_bias:
|
||||
ffn_params += intermediate_size * 2 + hidden_size
|
||||
return ffn_params
|
||||
|
||||
|
||||
def calc_norm_params(args, hidden_size: int, num_attention_heads: int) -> int:
|
||||
"""Calculate normalization parameters, adjusting for extra norms in complex models.
|
||||
|
||||
This extends the old approach (fixed 2 norms per layer) by adding heuristic support
|
||||
for extra normalization layers (e.g., OLMoE's q_norm, k_norm) in MoE or LoRA models,
|
||||
improving accuracy over the simpler assumption of uniform RMSNorm usage.
|
||||
"""
|
||||
num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
|
||||
head_dim = getattr(args, "head_dim", None)
|
||||
if head_dim is None:
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
|
||||
# Base: input + post-attention RMSNorm
|
||||
total = hidden_size * 2
|
||||
|
||||
# Heuristic: extra norms for MoE or complex attention
|
||||
has_experts = any(
|
||||
getattr(args, attr, 0) > 0
|
||||
for attr in ["num_local_experts", "num_experts", "n_routed_experts"]
|
||||
)
|
||||
if has_experts or hasattr(args, "q_lora_rank"):
|
||||
total += (num_attention_heads * head_dim) + (num_kv_heads * head_dim)
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def calculate_num_parameters(
|
||||
config: Dict, model_args_class: Optional[Type["BaseModelArgs"]] = None
|
||||
) -> int:
|
||||
"""Calculate the total number of parameters in a model based on its config.
|
||||
|
||||
By splitting into separate functions, we now support diverse
|
||||
architectures while maintaining readability and avoiding model-specific hardcoding.
|
||||
"""
|
||||
# Use the imported _get_classes function to get the ModelArgs class
|
||||
if model_args_class is None:
|
||||
_, model_args_class = _get_classes(config)
|
||||
|
||||
args = model_args_class.from_dict(config)
|
||||
|
||||
# Validate required fields
|
||||
required = [
|
||||
"hidden_size",
|
||||
"num_hidden_layers",
|
||||
"vocab_size",
|
||||
"num_attention_heads",
|
||||
"intermediate_size",
|
||||
]
|
||||
missing = [field for field in required if getattr(args, field, None) is None]
|
||||
if missing:
|
||||
raise ValueError(f"Config missing required fields: {missing}")
|
||||
|
||||
# Core config
|
||||
hidden_size = args.hidden_size
|
||||
num_layers = args.num_hidden_layers
|
||||
vocab_size = args.vocab_size
|
||||
num_attention_heads = args.num_attention_heads
|
||||
intermediate_size = args.intermediate_size
|
||||
|
||||
# Total calculation
|
||||
total_params = calc_embedding_params(vocab_size, hidden_size)
|
||||
for layer in range(num_layers):
|
||||
total_params += calc_attention_params(args, hidden_size, num_attention_heads)
|
||||
total_params += calc_ffn_or_moe_params(
|
||||
args, hidden_size, intermediate_size, layer
|
||||
)
|
||||
total_params += calc_norm_params(args, hidden_size, num_attention_heads)
|
||||
total_params += hidden_size # Final norm
|
||||
if not getattr(args, "tie_word_embeddings", True):
|
||||
total_params += hidden_size * vocab_size # LM head
|
||||
|
||||
return total_params
|
||||
|
||||
|
||||
def calculate_head_dim(config: Dict, args: BaseModelArgs) -> int:
|
||||
"""Infer head dimension dynamically from config or args."""
|
||||
head_dim = getattr(args, "head_dim", None)
|
||||
if head_dim is None:
|
||||
if "hidden_size" not in config or "num_attention_heads" not in config:
|
||||
raise ValueError(
|
||||
"Cannot compute head_dim: missing hidden_size or num_attention_heads"
|
||||
)
|
||||
head_dim = config["hidden_size"] // config["num_attention_heads"]
|
||||
return head_dim
|
||||
|
||||
|
||||
def estimate_mem(
|
||||
model_path: str,
|
||||
context_length: int = 4096,
|
||||
max_kv_size: Optional[int] = None,
|
||||
kv_bits: Optional[int] = None,
|
||||
kv_group_size: Optional[int] = None,
|
||||
tokens_to_generate: int = 0,
|
||||
) -> Tuple[Dict[str, float], str]:
|
||||
"""
|
||||
Estimate the memory usage of a model.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model.
|
||||
context_length: Context length of the model (prompt length in unbounded mode).
|
||||
max_kv_size: Maximum size of the KV cache (for bounded mode).
|
||||
kv_bits: Number of bits to use for quantized KV cache.
|
||||
kv_group_size: Group size to use for quantized KV cache.
|
||||
tokens_to_generate: Number of tokens to generate beyond the prompt.
|
||||
|
||||
Returns:
|
||||
A tuple of (results, mode) where results is a dictionary of memory usage
|
||||
and mode is a string indicating the mode of the KV cache.
|
||||
"""
|
||||
config, model_weight_size = fetch_metadata(model_path)
|
||||
bits_per_weight = compute_bits_per_weight_from_config(config)
|
||||
|
||||
# Determine the model class
|
||||
_, model_args_class = _get_classes(config)
|
||||
args = model_args_class.from_dict(config)
|
||||
|
||||
# Calculate the number of parameters
|
||||
num_parameters = calculate_num_parameters(config, model_args_class)
|
||||
|
||||
# Extract model architecture parameters needed for memory calculations
|
||||
num_layers = args.num_hidden_layers
|
||||
num_kv_heads = getattr(args, "num_key_value_heads", args.num_attention_heads)
|
||||
head_dim = calculate_head_dim(config, args)
|
||||
|
||||
# Default to fp16 (2 bytes per element) for KV cache unless quantized
|
||||
bytes_per_element = 2
|
||||
|
||||
# If kv_bits and kv_group_size are not provided, try to read from config
|
||||
if kv_bits is None or kv_group_size is None:
|
||||
# Try to get quantization settings from config
|
||||
quantization = config.get("quantization", {})
|
||||
quantization_config = config.get("quantization_config", {})
|
||||
|
||||
# Use the first available quantization config
|
||||
quant_info = quantization or quantization_config
|
||||
|
||||
if quant_info:
|
||||
kv_bits = kv_bits or quant_info.get("bits")
|
||||
kv_group_size = kv_group_size or quant_info.get("group_size")
|
||||
|
||||
# Calculate the model weight memory usage
|
||||
bytes_per_parameter = bits_per_weight / 8
|
||||
if model_weight_size:
|
||||
# Use the size from safetensors index if available
|
||||
model_size_gb = model_weight_size / (1024**3)
|
||||
else:
|
||||
# Calculate from parameter count
|
||||
model_size_gb = (num_parameters * bytes_per_parameter) / (1024**3)
|
||||
|
||||
# Estimate tokenizer size
|
||||
vocab_size = config.get("vocab_size", args.vocab_size)
|
||||
fixed_overhead_bytes = 25 * 1024 * 1024
|
||||
avg_token_size_bytes = 650
|
||||
tokenizer_size_bytes = (vocab_size * avg_token_size_bytes) + fixed_overhead_bytes
|
||||
tokenizer_size_gb = tokenizer_size_bytes / (1024**3)
|
||||
|
||||
# Determine the mode
|
||||
mode = "Bounded" if max_kv_size else "Unbounded"
|
||||
|
||||
# KV length is fixed to max_kv_size in bounded mode, or context_length in unbounded mode
|
||||
kv_length = max_kv_size if mode == "Bounded" else context_length
|
||||
|
||||
# Default step size from cache.py is 256
|
||||
step_size = 256
|
||||
kv_length_padded = ((kv_length + step_size - 1) // step_size) * step_size
|
||||
|
||||
# Calculate KV cache size based on whether quantization is used
|
||||
if kv_bits and kv_group_size:
|
||||
# Quantized cache calculations
|
||||
groups_per_head_dim = (head_dim + kv_group_size - 1) // kv_group_size
|
||||
elements_per_int = 8 * 4 // kv_bits
|
||||
|
||||
data_size = (
|
||||
num_kv_heads * kv_length_padded * (head_dim // elements_per_int) * 4
|
||||
) / (1024**3)
|
||||
quant_overhead = (
|
||||
num_kv_heads * kv_length_padded * groups_per_head_dim * 2 * 2
|
||||
) / (1024**3)
|
||||
per_layer_kv_size = 2 * (data_size + quant_overhead)
|
||||
|
||||
elements_per_token = (head_dim // elements_per_int) * 4
|
||||
scales_zeros_per_token = groups_per_head_dim * 2 * 2
|
||||
per_token_bytes = (
|
||||
2 * num_kv_heads * (elements_per_token + scales_zeros_per_token)
|
||||
)
|
||||
per_token_increase = (per_token_bytes * num_layers) / (1024**3)
|
||||
else:
|
||||
# Standard fp16 cache
|
||||
per_layer_kv_size = (
|
||||
2 * num_kv_heads * kv_length_padded * head_dim * bytes_per_element
|
||||
) / (1024**3)
|
||||
per_token_increase = (
|
||||
2 * num_kv_heads * head_dim * bytes_per_element * num_layers
|
||||
) / (1024**3)
|
||||
|
||||
total_kv_cache_size = num_layers * per_layer_kv_size
|
||||
|
||||
# Add the memory for generated tokens if specified
|
||||
if tokens_to_generate > 0:
|
||||
total_kv_cache_size += tokens_to_generate * per_token_increase
|
||||
|
||||
# For inference in MLX, estimate activation memory
|
||||
activation_size_gb = 0.03 * model_size_gb
|
||||
|
||||
overhead_gb = tokenizer_size_gb + activation_size_gb + (model_size_gb * 0.05)
|
||||
|
||||
# Total memory usage
|
||||
total_memory_gb = model_size_gb + total_kv_cache_size + overhead_gb
|
||||
|
||||
results = {
|
||||
"Weight": model_size_gb,
|
||||
"KV Cache": total_kv_cache_size,
|
||||
"Overhead": overhead_gb,
|
||||
"Total": total_memory_gb,
|
||||
"per_token_increase": per_token_increase,
|
||||
}
|
||||
|
||||
return results, mode
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
parser = argparse.ArgumentParser(description="MLX Model URAM Estimation Tool")
|
||||
parser.add_argument("model", help="Local model path or Hugging Face repo ID.")
|
||||
parser.add_argument(
|
||||
"--context-length",
|
||||
type=int,
|
||||
default=4096,
|
||||
help="Context length of the model (prompt length in unbounded mode).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-kv-size", type=int, help="Max KV cache size (enables bounded mode)."
|
||||
)
|
||||
parser.add_argument("--kv-bits", type=int, help="Bits for KV cache quantization.")
|
||||
parser.add_argument(
|
||||
"--kv-group-size", type=int, help="Group size for KV cache quantization."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokens-to-generate",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of tokens to generate beyond the prompt.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def print_table(
|
||||
results: Dict[str, float], mode: str, tokens_to_generate: int = 0
|
||||
) -> None:
|
||||
"""
|
||||
Print a memory usage table in a formatted way.
|
||||
|
||||
Args:
|
||||
results: Dictionary containing memory usage data (in GB unless specified).
|
||||
mode: Either "Bounded" or "Unbounded" to describe the KV Cache type.
|
||||
tokens_to_generate: Number of tokens generated (optional, defaults to 0).
|
||||
"""
|
||||
# Construct the title dynamically
|
||||
title = f"*Memory Usage Estimate ({mode} KV Cache"
|
||||
if tokens_to_generate > 0:
|
||||
title += f" after generating {tokens_to_generate:,} tokens"
|
||||
title += "):*"
|
||||
|
||||
# Define table formatting constants
|
||||
LINE_WIDTH = 34
|
||||
ITEM_COL_WIDTH = 17
|
||||
MEMORY_COL_WIDTH = 12
|
||||
|
||||
# Print header
|
||||
print(title)
|
||||
print("-" * LINE_WIDTH)
|
||||
print(f"| {'Item':<{ITEM_COL_WIDTH}} | {'Memory':<{MEMORY_COL_WIDTH}} |")
|
||||
print("-" * LINE_WIDTH)
|
||||
|
||||
# Define display order and handle missing keys gracefully
|
||||
display_order = ["Weight", "KV Cache", "Overhead", "Total"]
|
||||
for key in display_order:
|
||||
value = results.get(key, 0.0) # Default to 0.0 if key is missing
|
||||
memory_str = f"{value:.2f} GB"
|
||||
|
||||
# Print row (extra spaces for alignment)
|
||||
if key == "Total":
|
||||
print("-" * LINE_WIDTH)
|
||||
print(f"| {key:<{ITEM_COL_WIDTH}} | {memory_str:>{MEMORY_COL_WIDTH}} |")
|
||||
|
||||
print("-" * LINE_WIDTH)
|
||||
|
||||
# Add footer for Unbounded mode
|
||||
if mode == "Unbounded" and "per_token_increase" in results:
|
||||
per_token_gb = results["per_token_increase"]
|
||||
if per_token_gb > 0: # Avoid division by zero
|
||||
tokens_per_gb = math.floor(1 / per_token_gb)
|
||||
print(f"Additional tokens per 1GB increase: {tokens_per_gb:,}")
|
||||
else:
|
||||
print("Note: Per-token increase is zero or invalid.")
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
results, mode = estimate_mem(
|
||||
args.model,
|
||||
args.context_length,
|
||||
args.max_kv_size,
|
||||
args.kv_bits,
|
||||
args.kv_group_size,
|
||||
args.tokens_to_generate,
|
||||
)
|
||||
|
||||
print_table(results, mode, args.tokens_to_generate)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
369
llms/test_memory_estimation.py
Normal file
369
llms/test_memory_estimation.py
Normal file
@ -0,0 +1,369 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
"""
|
||||
Test script to validate memory usage estimation from estimate.py.
|
||||
Loads a model, runs inference, and compares actual vs estimated memory usage.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.estimate import (
|
||||
calculate_head_dim,
|
||||
compute_bits_per_weight_from_config,
|
||||
estimate_uram,
|
||||
fetch_config,
|
||||
)
|
||||
from mlx_lm.utils import generate, load, stream_generate
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
parser = argparse.ArgumentParser(description="Test Memory Estimation Accuracy")
|
||||
parser.add_argument("model", help="Local model path or Hugging Face repo ID.")
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="Once upon a time", help="Prompt for inference."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-file",
|
||||
type=str,
|
||||
default="prompt.txt",
|
||||
help="File containing the prompt.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-tokens", type=int, default=50, help="Number of tokens to generate."
|
||||
)
|
||||
parser.add_argument("--kv-bits", type=int, help="Bits for KV cache quantization.")
|
||||
parser.add_argument(
|
||||
"--kv-group-size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Group size for KV cache quantization.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-kv-size", type=int, help="Max KV cache size (bounded mode)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context-length",
|
||||
type=int,
|
||||
default=4096,
|
||||
help="Context length for estimation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Enable verbose logging."
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_memory_usage():
|
||||
"""Get current memory usage in GB."""
|
||||
used_gb = mx.metal.get_peak_memory() / (1024**3)
|
||||
mx.metal.reset_peak_memory()
|
||||
return used_gb
|
||||
|
||||
|
||||
def get_active_memory_gb():
|
||||
"""Get current active memory usage in GB."""
|
||||
return mx.metal.get_active_memory() / (1024**3)
|
||||
|
||||
|
||||
def force_gc_and_reset():
|
||||
"""Force garbage collection and reset memory counters."""
|
||||
gc.collect()
|
||||
mx.metal.reset_peak_memory()
|
||||
time.sleep(0.1) # Small delay to ensure memory operations complete
|
||||
|
||||
|
||||
def measure_kv_cache_memory(
|
||||
model_pkg, tokenizer, prompt_tokens, num_new_tokens=10, verbose=False
|
||||
):
|
||||
"""
|
||||
Directly measure KV cache memory by comparing memory before and after cache creation.
|
||||
|
||||
Args:
|
||||
model_pkg: The loaded model package (contains model and generate function)
|
||||
tokenizer: Tokenizer for the model
|
||||
prompt_tokens: Tokenized prompt
|
||||
num_new_tokens: Number of new tokens to generate
|
||||
verbose: Whether to print verbose output
|
||||
|
||||
Returns:
|
||||
Tuple of (kv_cache_size_gb, kv_per_token_gb)
|
||||
"""
|
||||
# Force clean memory state
|
||||
force_gc_and_reset()
|
||||
|
||||
# Get baseline memory
|
||||
baseline = get_active_memory_gb()
|
||||
if verbose:
|
||||
print(f"Baseline memory before KV cache: {baseline:.4f} GB")
|
||||
|
||||
# Create inputs
|
||||
inputs = mx.array([prompt_tokens])
|
||||
|
||||
# First measure memory with just the prompt - no generation
|
||||
# Just do a forward pass to build the KV cache
|
||||
logits = model_pkg.model(inputs)
|
||||
mx.eval(logits)
|
||||
|
||||
# Measure memory with just prompt in KV cache
|
||||
prompt_kv_memory = get_active_memory_gb() - baseline
|
||||
if verbose:
|
||||
print(
|
||||
f"Memory after prompt KV cache: {prompt_kv_memory:.4f} GB for {len(prompt_tokens)} tokens"
|
||||
)
|
||||
|
||||
# Reset for generation test
|
||||
force_gc_and_reset()
|
||||
baseline_with_prompt_kv = get_active_memory_gb()
|
||||
|
||||
# For generation, we need to use the mlx_lm generate function, not model.generate
|
||||
# Create a simple manual generation loop to measure memory impact
|
||||
input_ids = mx.array([prompt_tokens])
|
||||
|
||||
# Generate tokens one by one to measure KV cache growth
|
||||
for _ in range(num_new_tokens):
|
||||
# Forward pass
|
||||
logits = model_pkg.model(input_ids)
|
||||
next_token_logits = logits[0, -1, :]
|
||||
next_token = mx.argmax(next_token_logits)
|
||||
input_ids = mx.concatenate([input_ids, mx.array([[next_token]])], axis=1)
|
||||
mx.eval(input_ids)
|
||||
|
||||
# Measure memory after generation including KV cache
|
||||
total_kv_memory = (
|
||||
get_active_memory_gb() - baseline_with_prompt_kv + prompt_kv_memory
|
||||
)
|
||||
|
||||
# Calculate per-token KV cache size
|
||||
total_tokens = len(prompt_tokens) + num_new_tokens
|
||||
if num_new_tokens > 0:
|
||||
per_token_gb = (total_kv_memory - prompt_kv_memory) / num_new_tokens
|
||||
else:
|
||||
per_token_gb = 0
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
f"Total KV cache memory: {total_kv_memory:.4f} GB for {total_tokens} tokens"
|
||||
)
|
||||
print(f"Measured KV cache per token: {per_token_gb:.8f} GB")
|
||||
|
||||
return total_kv_memory, per_token_gb
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
verbose = args.verbose
|
||||
|
||||
# Measure baseline memory before loading model
|
||||
force_gc_and_reset()
|
||||
baseline_memory_gb = get_memory_usage()
|
||||
print(f"Baseline memory usage: {baseline_memory_gb:.4f} GB")
|
||||
|
||||
print(f"Loading model from {args.model}...")
|
||||
|
||||
# Clear memory before starting
|
||||
force_gc_and_reset()
|
||||
|
||||
# Load the model
|
||||
model_pkg, tokenizer = load(args.model)
|
||||
|
||||
# Memory after loading model but before materializing parameters
|
||||
model_load_memory_gb = get_memory_usage()
|
||||
|
||||
# Materialize model parameters to get accurate parameter memory usage
|
||||
print("Materializing model parameters...")
|
||||
force_gc_and_reset()
|
||||
|
||||
# Force materialization of all model parameters
|
||||
for p in model_pkg.model.parameters().values():
|
||||
mx.eval(p)
|
||||
|
||||
# Get memory used by materialized parameters
|
||||
parameter_memory_gb = get_active_memory_gb()
|
||||
print(f"Model parameters memory: {parameter_memory_gb:.4f} GB")
|
||||
|
||||
# Try to read prompt from file if it exists, otherwise use default
|
||||
prompt = args.prompt
|
||||
try:
|
||||
if os.path.exists(args.prompt_file):
|
||||
with open(args.prompt_file, "r") as f:
|
||||
file_prompt = f.read().strip()
|
||||
if file_prompt:
|
||||
prompt = file_prompt
|
||||
print(
|
||||
f"Using prompt from {args.prompt_file} ({len(prompt)} characters)"
|
||||
)
|
||||
else:
|
||||
print(f"Empty prompt file, using default prompt")
|
||||
except Exception as e:
|
||||
print(f"Error reading prompt file: {e}, using default prompt")
|
||||
|
||||
# Count tokens in the prompt
|
||||
tokens = tokenizer.encode(prompt)
|
||||
prompt_token_count = len(tokens)
|
||||
print(f"Prompt length: {prompt_token_count} tokens")
|
||||
|
||||
# Run inference with memory tracking
|
||||
print(f"Running inference with prompt...")
|
||||
|
||||
# Reset memory state
|
||||
force_gc_and_reset()
|
||||
|
||||
# First do a simple forward pass to measure activation memory
|
||||
print("Running single forward pass...")
|
||||
|
||||
# Tokenize input
|
||||
inputs = mx.array([tokens])
|
||||
|
||||
# Create an evaluation function that we will trace
|
||||
def forward_pass():
|
||||
return model_pkg.model(inputs)
|
||||
|
||||
# Trace the function to capture the actual memory during execution
|
||||
output = forward_pass()
|
||||
mx.eval(output)
|
||||
|
||||
# Get memory used during forward pass
|
||||
forward_memory_gb = get_memory_usage()
|
||||
print(f"Peak memory during forward pass: {forward_memory_gb:.4f} GB")
|
||||
|
||||
# Get model config for additional details
|
||||
config = fetch_config(args.model)
|
||||
bits_per_weight = compute_bits_per_weight_from_config(config)
|
||||
|
||||
# Get the necessary parameters for KV cache calculation from the model config
|
||||
from mlx_lm.utils import _get_classes
|
||||
|
||||
_, model_args_class = _get_classes(config)
|
||||
model_args = model_args_class.from_dict(config)
|
||||
|
||||
# Extract the parameters needed for KV cache calculation
|
||||
head_dim = calculate_head_dim(config, model_args)
|
||||
num_kv_heads = getattr(
|
||||
model_args, "num_key_value_heads", model_args.num_attention_heads
|
||||
)
|
||||
num_layers = model_args.num_hidden_layers
|
||||
|
||||
# Now directly measure the KV cache memory
|
||||
print("\nDirectly measuring KV cache memory usage...")
|
||||
actual_kv_cache_gb, actual_per_token_gb = measure_kv_cache_memory(
|
||||
model_pkg,
|
||||
tokenizer,
|
||||
tokens,
|
||||
num_new_tokens=min(20, args.num_tokens),
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
# Measure memory during token generation (inference) using proper generate function
|
||||
print("\nMeasuring memory during full token generation (streaming)...")
|
||||
force_gc_and_reset()
|
||||
baseline_for_generation = get_active_memory_gb()
|
||||
|
||||
# Use stream_generate to get token-by-token memory measurements
|
||||
generation_text = ""
|
||||
peak_memory_gb = 0
|
||||
total_tokens_generated = 0
|
||||
token_memory_profile = []
|
||||
|
||||
# Stream generation and track memory for each token
|
||||
for response in stream_generate(
|
||||
model_pkg.model, tokenizer, prompt, max_tokens=args.num_tokens
|
||||
):
|
||||
generation_text += response.text
|
||||
total_tokens_generated += 1
|
||||
|
||||
# Track memory per token
|
||||
current_memory = response.peak_memory
|
||||
peak_memory_gb = max(peak_memory_gb, current_memory)
|
||||
|
||||
# Record memory for this token
|
||||
token_memory_profile.append(current_memory)
|
||||
|
||||
if verbose and total_tokens_generated % 10 == 0:
|
||||
print(
|
||||
f"Generated {total_tokens_generated} tokens, current memory: {current_memory:.4f} GB"
|
||||
)
|
||||
|
||||
# Calculate final memory usage
|
||||
generation_memory_gb = peak_memory_gb
|
||||
print(
|
||||
f"Peak memory during generation of {total_tokens_generated} tokens: {generation_memory_gb:.4f} GB"
|
||||
)
|
||||
|
||||
# You can also add this to get more detailed memory profile analysis
|
||||
if verbose:
|
||||
print("\nMemory growth during generation:")
|
||||
for i, mem in enumerate(token_memory_profile):
|
||||
if i % 5 == 0 or i == len(token_memory_profile) - 1:
|
||||
print(f" Token {i+1}: {mem:.4f} GB")
|
||||
|
||||
# Calculate activation memory (peak memory during generation minus parameter memory and KV cache)
|
||||
actual_activation_memory_gb = (
|
||||
generation_memory_gb - parameter_memory_gb - actual_kv_cache_gb
|
||||
)
|
||||
if actual_activation_memory_gb < 0:
|
||||
# This can happen due to memory reclamation between measurements
|
||||
actual_activation_memory_gb = (
|
||||
0.01 * parameter_memory_gb
|
||||
) # Use a reasonable fallback
|
||||
|
||||
# Get estimated memory usage from estimate.py
|
||||
estimated_results, mode = estimate_uram(
|
||||
args.model,
|
||||
context_length=args.context_length,
|
||||
max_kv_size=args.max_kv_size,
|
||||
kv_bits=args.kv_bits,
|
||||
kv_group_size=args.kv_group_size,
|
||||
initial_prompt_length=prompt_token_count,
|
||||
extra_tokens=total_tokens_generated,
|
||||
)
|
||||
|
||||
# Compare estimation accuracy
|
||||
print("\n--- MEMORY ESTIMATION ACCURACY ---")
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Architecture: {config.get('model_type', 'unknown')}")
|
||||
print(f"Quantization: {bits_per_weight} bits per weight")
|
||||
print(f"Total tokens processed: {prompt_token_count + total_tokens_generated}")
|
||||
print("-" * 40)
|
||||
print(f"Actual model parameters memory: {parameter_memory_gb:.3g} GB")
|
||||
print(f"Estimated model parameters memory: {estimated_results['Model']:.3g} GB")
|
||||
print(
|
||||
f"Model memory error: {abs(parameter_memory_gb - estimated_results['Model']):.3g} GB"
|
||||
)
|
||||
print("-" * 40)
|
||||
print(f"Actual KV cache memory: {actual_kv_cache_gb:.3g} GB")
|
||||
print(f"Estimated KV cache memory: {estimated_results['KV Cache']:.3g} GB")
|
||||
print(
|
||||
f"KV cache memory error: {abs(actual_kv_cache_gb - estimated_results['KV Cache']):.3g} GB"
|
||||
)
|
||||
print("-" * 40)
|
||||
print(f"Actual per-token KV increase: {actual_per_token_gb:.6g} GB")
|
||||
print(
|
||||
f"Estimated per-token KV increase: {estimated_results['per_token_increase']:.6g} GB"
|
||||
)
|
||||
print(
|
||||
f"Per-token KV error: {abs(actual_per_token_gb - estimated_results['per_token_increase']):.6g} GB"
|
||||
)
|
||||
print("-" * 40)
|
||||
print(f"Actual activation memory: {actual_activation_memory_gb:.3g} GB")
|
||||
print(f"Estimated activation memory: {estimated_results['Activations']:.3g} GB")
|
||||
print(
|
||||
f"Activation memory error: {abs(actual_activation_memory_gb - estimated_results['Activations']):.3g} GB"
|
||||
)
|
||||
print("-" * 40)
|
||||
print(f"Total peak memory (actual): {generation_memory_gb:.3g} GB")
|
||||
print(f"Total memory (estimated): {estimated_results['Total']:.3g} GB")
|
||||
print(
|
||||
f"Total memory error: {abs(generation_memory_gb - estimated_results['Total']):.3g} GB"
|
||||
)
|
||||
print(f"Mode: {mode}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user