mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Add memory estimation tool for MLX language models

This commit introduces a comprehensive memory estimation utility for MLX language models, supporting:

- Dynamic parameter calculation across diverse model architectures
- Handling of quantized and standard models
- Estimation of model weights, KV cache, and overhead memory
- Support for bounded and unbounded KV cache modes
- Flexible configuration via command-line arguments

The new tool provides detailed memory usage insights for different model configurations and generation scenarios.

parent 877d2a345b
commit 7ee76a32a4
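For orientation, a minimal sketch of calling the new estimator from Python once this commit is installed. The repo ID below is only a placeholder, not a model referenced by this commit; the signature and the result keys ("Weight", "KV Cache", "Overhead", "Total", "per_token_increase") come from estimate_memory.py as added here.

    from mlx_lm.estimate_memory import estimate_mem

    # "mlx-community/Example-Model-4bit" is a placeholder repo ID.
    results, mode = estimate_mem(
        "mlx-community/Example-Model-4bit",
        context_length=4096,
        tokens_to_generate=256,
    )
    print(f"{mode} KV cache, estimated total: {results['Total']:.2f} GB")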
llms/mlx_lm/__init__.py

@@ -7,3 +7,9 @@ from ._version import __version__
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

 from .utils import convert, generate, load, stream_generate
+
+
+def get_estimate_mem():
+    from .estimate_memory import estimate_mem
+
+    return estimate_mem
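A short sketch of how the lazy accessor added above might be used; deferring the import this way plausibly keeps the heavier estimate_memory dependencies (huggingface_hub, model classes) off the plain `import mlx_lm` path. The local model path below is a placeholder.

    import mlx_lm

    # get_estimate_mem() defers `from .estimate_memory import estimate_mem`
    # until estimation is actually requested.
    estimate_mem = mlx_lm.get_estimate_mem()
    results, mode = estimate_mem("path/to/local-model", context_length=2048)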
llms/mlx_lm/estimate_memory.py (new file, 470 lines)

@@ -0,0 +1,470 @@
# Copyright © 2023-2025 Apple Inc.

import argparse
import json
import math
import os
from pathlib import Path
from typing import Dict, Optional, Tuple, Type

from huggingface_hub import hf_hub_download, try_to_load_from_cache
from mlx_lm.models.base import BaseModelArgs
from mlx_lm.utils import _get_classes


def fetch_metadata(model_path: str) -> Tuple[Dict, Optional[int]]:
    """Fetch config.json and optionally model.safetensors.index.json for weights size."""
    config = fetch_config(model_path)
    model_weight_size = None
    if not os.path.isdir(model_path):
        try:
            index_path = hf_hub_download(
                repo_id=model_path, filename="model.safetensors.index.json"
            )
            with open(index_path, "r") as f:
                index = json.load(f)
            model_weight_size = index.get("metadata", {}).get("total_size")
        except Exception:
            pass
    return config, model_weight_size


def fetch_config(model_path: str) -> Dict:
    """Fetch or load config.json without downloading the full model, checking cache first."""
    if os.path.isdir(model_path):
        config_path = Path(model_path) / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_path}")
        with open(config_path, "r") as f:
            return json.load(f)
    else:
        cached_path = try_to_load_from_cache(
            repo_id=model_path, filename="config.json", repo_type="model"
        )
        if cached_path and os.path.exists(cached_path):
            with open(cached_path, "r") as f:
                return json.load(f)
        try:
            config_path = hf_hub_download(repo_id=model_path, filename="config.json")
            with open(config_path, "r") as f:
                return json.load(f)
        except Exception as e:
            raise ValueError(f"Failed to fetch config.json from {model_path}: {str(e)}")


def compute_bits_per_weight_from_config(config: Dict) -> float:
    """Infer bits-per-weight from config, defaulting to 16 (FP16) if unquantized."""
    quantization = config.get("quantization", {})
    bits = quantization.get("bits")
    return float(bits) if bits is not None else 16.0


def calc_embedding_params(vocab_size: int, hidden_size: int) -> int:
    """Calculate parameters for the embedding layer."""
    return vocab_size * hidden_size


def calc_attention_params(args, hidden_size: int, num_attention_heads: int) -> int:
    """Calculate parameters for one attention layer, handling standard and LoRA variants.

    This function supports both standard multi-head attention (e.g., Mixtral, OLMoE) and
    LoRA-like attention (e.g., DeepSeek V3) by checking for q_lora_rank, allowing flexibility
    over the older hardcoded approach that assumed uniform QKV dimensions.
    """
    num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
    head_dim = getattr(args, "head_dim", None)
    if head_dim is None:
        head_dim = hidden_size // num_attention_heads
    has_bias = getattr(args, "attention_bias", False)

    # Standard attention (Q, K, V, O)
    if not hasattr(args, "q_lora_rank") or not args.q_lora_rank:
        q_params = hidden_size * (num_attention_heads * head_dim)
        k_params = hidden_size * (num_kv_heads * head_dim)
        v_params = k_params
        o_params = (num_attention_heads * head_dim) * hidden_size
    # LoRA-like attention (e.g., DeepSeek V3)
    else:
        q_head_dim = getattr(args, "qk_nope_head_dim", 0) + getattr(
            args, "qk_rope_head_dim", head_dim
        )
        v_head_dim = getattr(args, "v_head_dim", head_dim)
        qk_rope_dim = getattr(args, "qk_rope_head_dim", head_dim)
        q_params = hidden_size * args.q_lora_rank + args.q_lora_rank * (
            num_attention_heads * q_head_dim
        )
        k_params = hidden_size * (args.kv_lora_rank + qk_rope_dim)
        v_params = args.kv_lora_rank * (
            num_attention_heads * (q_head_dim - qk_rope_dim + v_head_dim)
        )
        o_params = (num_attention_heads * v_head_dim) * hidden_size

    total = q_params + k_params + v_params + o_params
    if has_bias:
        total += (
            num_attention_heads * head_dim + num_kv_heads * head_dim * 2 + hidden_size
        )

    return total


def calc_ffn_or_moe_params(
    args, hidden_size: int, intermediate_size: int, layer_idx: int
) -> int:
    """Calculate parameters for FFN or MoE layer, switching based on config.

    Unlike the previous hardcoded FFN-only calculation, this function dynamically handles
    Mixture of Experts (MoE) models like Mixtral, OLMoE, and DeepSeek V3 by detecting
    expert-related fields (num_experts, n_routed_experts) and adjusting for dense vs.
    MoE layers, supporting varied intermediate sizes and shared experts.
    """
    num_experts = max(
        getattr(args, "num_local_experts", 0),
        getattr(args, "num_experts", 0),
        getattr(args, "n_routed_experts", 0),
    )
    moe_intermediate_size = getattr(args, "moe_intermediate_size", intermediate_size)
    dense_up_to = (
        getattr(args, "first_k_dense_replace", 0)
        if num_experts
        else args.num_hidden_layers
    )
    has_bias = getattr(args, "mlp_bias", False)
    shared_experts = getattr(args, "n_shared_experts", 0)

    if num_experts and layer_idx >= dense_up_to:
        # MoE: gate + expert FFNs
        gate_params = hidden_size * num_experts
        expert_params = (
            num_experts * hidden_size * moe_intermediate_size * 3
        )  # gate_proj, up_proj, down_proj
        shared_params = (
            shared_experts * hidden_size * moe_intermediate_size * 3
            if shared_experts
            else 0
        )
        return gate_params + expert_params + shared_params
    else:
        # Dense FFN
        ffn_params = (
            hidden_size * intermediate_size * 2 + intermediate_size * hidden_size
        )
        if has_bias:
            ffn_params += intermediate_size * 2 + hidden_size
        return ffn_params


def calc_norm_params(args, hidden_size: int, num_attention_heads: int) -> int:
    """Calculate normalization parameters, adjusting for extra norms in complex models.

    This extends the old approach (fixed 2 norms per layer) by adding heuristic support
    for extra normalization layers (e.g., OLMoE's q_norm, k_norm) in MoE or LoRA models,
    improving accuracy over the simpler assumption of uniform RMSNorm usage.
    """
    num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
    head_dim = getattr(args, "head_dim", None)
    if head_dim is None:
        head_dim = hidden_size // num_attention_heads

    # Base: input + post-attention RMSNorm
    total = hidden_size * 2

    # Heuristic: extra norms for MoE or complex attention
    has_experts = any(
        getattr(args, attr, 0) > 0
        for attr in ["num_local_experts", "num_experts", "n_routed_experts"]
    )
    if has_experts or hasattr(args, "q_lora_rank"):
        total += (num_attention_heads * head_dim) + (num_kv_heads * head_dim)

    return total


def calculate_num_parameters(
    config: Dict, model_args_class: Optional[Type["BaseModelArgs"]] = None
) -> int:
    """Calculate the total number of parameters in a model based on its config.

    By splitting into separate functions, we now support diverse
    architectures while maintaining readability and avoiding model-specific hardcoding.
    """
    # Use the imported _get_classes function to get the ModelArgs class
    if model_args_class is None:
        _, model_args_class = _get_classes(config)

    args = model_args_class.from_dict(config)

    # Validate required fields
    required = [
        "hidden_size",
        "num_hidden_layers",
        "vocab_size",
        "num_attention_heads",
        "intermediate_size",
    ]
    missing = [field for field in required if getattr(args, field, None) is None]
    if missing:
        raise ValueError(f"Config missing required fields: {missing}")

    # Core config
    hidden_size = args.hidden_size
    num_layers = args.num_hidden_layers
    vocab_size = args.vocab_size
    num_attention_heads = args.num_attention_heads
    intermediate_size = args.intermediate_size

    # Total calculation
    total_params = calc_embedding_params(vocab_size, hidden_size)
    for layer in range(num_layers):
        total_params += calc_attention_params(args, hidden_size, num_attention_heads)
        total_params += calc_ffn_or_moe_params(
            args, hidden_size, intermediate_size, layer
        )
        total_params += calc_norm_params(args, hidden_size, num_attention_heads)
    total_params += hidden_size  # Final norm
    if not getattr(args, "tie_word_embeddings", True):
        total_params += hidden_size * vocab_size  # LM head

    return total_params


def calculate_head_dim(config: Dict, args: BaseModelArgs) -> int:
    """Infer head dimension dynamically from config or args."""
    head_dim = getattr(args, "head_dim", None)
    if head_dim is None:
        if "hidden_size" not in config or "num_attention_heads" not in config:
            raise ValueError(
                "Cannot compute head_dim: missing hidden_size or num_attention_heads"
            )
        head_dim = config["hidden_size"] // config["num_attention_heads"]
    return head_dim


def estimate_mem(
    model_path: str,
    context_length: int = 4096,
    max_kv_size: Optional[int] = None,
    kv_bits: Optional[int] = None,
    kv_group_size: Optional[int] = None,
    tokens_to_generate: int = 0,
) -> Tuple[Dict[str, float], str]:
    """
    Estimate the memory usage of a model.

    Args:
        model_path: Path to the model.
        context_length: Context length of the model (prompt length in unbounded mode).
        max_kv_size: Maximum size of the KV cache (for bounded mode).
        kv_bits: Number of bits to use for quantized KV cache.
        kv_group_size: Group size to use for quantized KV cache.
        tokens_to_generate: Number of tokens to generate beyond the prompt.

    Returns:
        A tuple of (results, mode) where results is a dictionary of memory usage
        and mode is a string indicating the mode of the KV cache.
    """
    config, model_weight_size = fetch_metadata(model_path)
    bits_per_weight = compute_bits_per_weight_from_config(config)

    # Determine the model class
    _, model_args_class = _get_classes(config)
    args = model_args_class.from_dict(config)

    # Calculate the number of parameters
    num_parameters = calculate_num_parameters(config, model_args_class)

    # Extract model architecture parameters needed for memory calculations
    num_layers = args.num_hidden_layers
    num_kv_heads = getattr(args, "num_key_value_heads", args.num_attention_heads)
    head_dim = calculate_head_dim(config, args)

    # Default to fp16 (2 bytes per element) for KV cache unless quantized
    bytes_per_element = 2

    # If kv_bits and kv_group_size are not provided, try to read from config
    if kv_bits is None or kv_group_size is None:
        # Try to get quantization settings from config
        quantization = config.get("quantization", {})
        quantization_config = config.get("quantization_config", {})

        # Use the first available quantization config
        quant_info = quantization or quantization_config

        if quant_info:
            kv_bits = kv_bits or quant_info.get("bits")
            kv_group_size = kv_group_size or quant_info.get("group_size")

    # Calculate the model weight memory usage
    bytes_per_parameter = bits_per_weight / 8
    if model_weight_size:
        # Use the size from safetensors index if available
        model_size_gb = model_weight_size / (1024**3)
    else:
        # Calculate from parameter count
        model_size_gb = (num_parameters * bytes_per_parameter) / (1024**3)

    # Estimate tokenizer size
    vocab_size = config.get("vocab_size", args.vocab_size)
    fixed_overhead_bytes = 25 * 1024 * 1024
    avg_token_size_bytes = 650
    tokenizer_size_bytes = (vocab_size * avg_token_size_bytes) + fixed_overhead_bytes
    tokenizer_size_gb = tokenizer_size_bytes / (1024**3)

    # Determine the mode
    mode = "Bounded" if max_kv_size else "Unbounded"

    # KV length is fixed to max_kv_size in bounded mode, or context_length in unbounded mode
    kv_length = max_kv_size if mode == "Bounded" else context_length

    # Default step size from cache.py is 256
    step_size = 256
    kv_length_padded = ((kv_length + step_size - 1) // step_size) * step_size

    # Calculate KV cache size based on whether quantization is used
    if kv_bits and kv_group_size:
        # Quantized cache calculations
        groups_per_head_dim = (head_dim + kv_group_size - 1) // kv_group_size
        elements_per_int = 8 * 4 // kv_bits

        data_size = (
            num_kv_heads * kv_length_padded * (head_dim // elements_per_int) * 4
        ) / (1024**3)
        quant_overhead = (
            num_kv_heads * kv_length_padded * groups_per_head_dim * 2 * 2
        ) / (1024**3)
        per_layer_kv_size = 2 * (data_size + quant_overhead)

        elements_per_token = (head_dim // elements_per_int) * 4
        scales_zeros_per_token = groups_per_head_dim * 2 * 2
        per_token_bytes = (
            2 * num_kv_heads * (elements_per_token + scales_zeros_per_token)
        )
        per_token_increase = (per_token_bytes * num_layers) / (1024**3)
    else:
        # Standard fp16 cache
        per_layer_kv_size = (
            2 * num_kv_heads * kv_length_padded * head_dim * bytes_per_element
        ) / (1024**3)
        per_token_increase = (
            2 * num_kv_heads * head_dim * bytes_per_element * num_layers
        ) / (1024**3)

    total_kv_cache_size = num_layers * per_layer_kv_size

    # Add the memory for generated tokens if specified
    if tokens_to_generate > 0:
        total_kv_cache_size += tokens_to_generate * per_token_increase

    # For inference in MLX, estimate activation memory
    activation_size_gb = 0.03 * model_size_gb

    overhead_gb = tokenizer_size_gb + activation_size_gb + (model_size_gb * 0.05)

    # Total memory usage
    total_memory_gb = model_size_gb + total_kv_cache_size + overhead_gb

    results = {
        "Weight": model_size_gb,
        "KV Cache": total_kv_cache_size,
        "Overhead": overhead_gb,
        "Total": total_memory_gb,
        "per_token_increase": per_token_increase,
    }

    return results, mode


def setup_arg_parser():
    parser = argparse.ArgumentParser(description="MLX Model URAM Estimation Tool")
    parser.add_argument("model", help="Local model path or Hugging Face repo ID.")
    parser.add_argument(
        "--context-length",
        type=int,
        default=4096,
        help="Context length of the model (prompt length in unbounded mode).",
    )
    parser.add_argument(
        "--max-kv-size", type=int, help="Max KV cache size (enables bounded mode)."
    )
    parser.add_argument("--kv-bits", type=int, help="Bits for KV cache quantization.")
    parser.add_argument(
        "--kv-group-size", type=int, help="Group size for KV cache quantization."
    )
    parser.add_argument(
        "--tokens-to-generate",
        type=int,
        default=0,
        help="Number of tokens to generate beyond the prompt.",
    )
    return parser


def print_table(
    results: Dict[str, float], mode: str, tokens_to_generate: int = 0
) -> None:
    """
    Print a memory usage table in a formatted way.

    Args:
        results: Dictionary containing memory usage data (in GB unless specified).
        mode: Either "Bounded" or "Unbounded" to describe the KV Cache type.
        tokens_to_generate: Number of tokens generated (optional, defaults to 0).
    """
    # Construct the title dynamically
    title = f"*Memory Usage Estimate ({mode} KV Cache"
    if tokens_to_generate > 0:
        title += f" after generating {tokens_to_generate:,} tokens"
    title += "):*"

    # Define table formatting constants
    LINE_WIDTH = 34
    ITEM_COL_WIDTH = 17
    MEMORY_COL_WIDTH = 12

    # Print header
    print(title)
    print("-" * LINE_WIDTH)
    print(f"| {'Item':<{ITEM_COL_WIDTH}} | {'Memory':<{MEMORY_COL_WIDTH}} |")
    print("-" * LINE_WIDTH)

    # Define display order and handle missing keys gracefully
    display_order = ["Weight", "KV Cache", "Overhead", "Total"]
    for key in display_order:
        value = results.get(key, 0.0)  # Default to 0.0 if key is missing
        memory_str = f"{value:.2f} GB"

        # Print row (extra spaces for alignment)
        if key == "Total":
            print("-" * LINE_WIDTH)
        print(f"| {key:<{ITEM_COL_WIDTH}} | {memory_str:>{MEMORY_COL_WIDTH}} |")

    print("-" * LINE_WIDTH)

    # Add footer for Unbounded mode
    if mode == "Unbounded" and "per_token_increase" in results:
        per_token_gb = results["per_token_increase"]
        if per_token_gb > 0:  # Avoid division by zero
            tokens_per_gb = math.floor(1 / per_token_gb)
            print(f"Additional tokens per 1GB increase: {tokens_per_gb:,}")
        else:
            print("Note: Per-token increase is zero or invalid.")


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    results, mode = estimate_mem(
        args.model,
        args.context_length,
        args.max_kv_size,
        args.kv_bits,
        args.kv_group_size,
        args.tokens_to_generate,
    )

    print_table(results, mode, args.tokens_to_generate)


if __name__ == "__main__":
    main()
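To make the unquantized KV-cache arithmetic above concrete, here is a small worked example. The dimensions (32 layers, 8 KV heads, head_dim 128, a 4096-token context) are hypothetical Llama-style values chosen only for illustration; the formulas mirror the fp16 branch of estimate_mem.

    # Hypothetical dimensions, chosen only to illustrate the fp16 formulas above.
    num_layers, num_kv_heads, head_dim = 32, 8, 128
    kv_length_padded = 4096  # already a multiple of the 256-token step size
    bytes_per_element = 2    # fp16

    per_layer_kv_gb = (
        2 * num_kv_heads * kv_length_padded * head_dim * bytes_per_element
    ) / (1024**3)
    total_kv_gb = num_layers * per_layer_kv_gb  # ~0.50 GB for the full context
    per_token_gb = (
        2 * num_kv_heads * head_dim * bytes_per_element * num_layers
    ) / (1024**3)  # ~0.000122 GB, i.e. roughly 8,192 extra tokens per GB
    print(f"{total_kv_gb:.2f} GB cache, {per_token_gb * 1024:.3f} MB per extra token")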
llms/test_memory_estimation.py (new file, 369 lines)

@@ -0,0 +1,369 @@
#!/usr/bin/env python3
# Copyright © 2023-2025 Apple Inc.

"""
Test script to validate memory usage estimation from estimate_memory.py.
Loads a model, runs inference, and compares actual vs estimated memory usage.
"""

import argparse
import gc
import os
import time
from typing import Dict, Optional, Tuple

import mlx.core as mx
from mlx_lm.estimate_memory import (
    calculate_head_dim,
    compute_bits_per_weight_from_config,
    estimate_mem,
    fetch_config,
)
from mlx_lm.utils import generate, load, stream_generate


def setup_arg_parser():
    parser = argparse.ArgumentParser(description="Test Memory Estimation Accuracy")
    parser.add_argument("model", help="Local model path or Hugging Face repo ID.")
    parser.add_argument(
        "--prompt", type=str, default="Once upon a time", help="Prompt for inference."
    )
    parser.add_argument(
        "--prompt-file",
        type=str,
        default="prompt.txt",
        help="File containing the prompt.",
    )
    parser.add_argument(
        "--num-tokens", type=int, default=50, help="Number of tokens to generate."
    )
    parser.add_argument("--kv-bits", type=int, help="Bits for KV cache quantization.")
    parser.add_argument(
        "--kv-group-size",
        type=int,
        default=64,
        help="Group size for KV cache quantization.",
    )
    parser.add_argument(
        "--max-kv-size", type=int, help="Max KV cache size (bounded mode)."
    )
    parser.add_argument(
        "--context-length",
        type=int,
        default=4096,
        help="Context length for estimation.",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Enable verbose logging."
    )
    return parser


def get_memory_usage():
    """Get current memory usage in GB."""
    used_gb = mx.metal.get_peak_memory() / (1024**3)
    mx.metal.reset_peak_memory()
    return used_gb


def get_active_memory_gb():
    """Get current active memory usage in GB."""
    return mx.metal.get_active_memory() / (1024**3)


def force_gc_and_reset():
    """Force garbage collection and reset memory counters."""
    gc.collect()
    mx.metal.reset_peak_memory()
    time.sleep(0.1)  # Small delay to ensure memory operations complete


def measure_kv_cache_memory(
    model_pkg, tokenizer, prompt_tokens, num_new_tokens=10, verbose=False
):
    """
    Directly measure KV cache memory by comparing memory before and after cache creation.

    Args:
        model_pkg: The loaded model package (contains model and generate function)
        tokenizer: Tokenizer for the model
        prompt_tokens: Tokenized prompt
        num_new_tokens: Number of new tokens to generate
        verbose: Whether to print verbose output

    Returns:
        Tuple of (kv_cache_size_gb, kv_per_token_gb)
    """
    # Force clean memory state
    force_gc_and_reset()

    # Get baseline memory
    baseline = get_active_memory_gb()
    if verbose:
        print(f"Baseline memory before KV cache: {baseline:.4f} GB")

    # Create inputs
    inputs = mx.array([prompt_tokens])

    # First measure memory with just the prompt - no generation
    # Just do a forward pass to build the KV cache
    logits = model_pkg.model(inputs)
    mx.eval(logits)

    # Measure memory with just prompt in KV cache
    prompt_kv_memory = get_active_memory_gb() - baseline
    if verbose:
        print(
            f"Memory after prompt KV cache: {prompt_kv_memory:.4f} GB for {len(prompt_tokens)} tokens"
        )

    # Reset for generation test
    force_gc_and_reset()
    baseline_with_prompt_kv = get_active_memory_gb()

    # For generation, we need to use the mlx_lm generate function, not model.generate
    # Create a simple manual generation loop to measure memory impact
    input_ids = mx.array([prompt_tokens])

    # Generate tokens one by one to measure KV cache growth
    for _ in range(num_new_tokens):
        # Forward pass
        logits = model_pkg.model(input_ids)
        next_token_logits = logits[0, -1, :]
        next_token = mx.argmax(next_token_logits)
        input_ids = mx.concatenate([input_ids, mx.array([[next_token]])], axis=1)
        mx.eval(input_ids)

    # Measure memory after generation including KV cache
    total_kv_memory = (
        get_active_memory_gb() - baseline_with_prompt_kv + prompt_kv_memory
    )

    # Calculate per-token KV cache size
    total_tokens = len(prompt_tokens) + num_new_tokens
    if num_new_tokens > 0:
        per_token_gb = (total_kv_memory - prompt_kv_memory) / num_new_tokens
    else:
        per_token_gb = 0

    if verbose:
        print(
            f"Total KV cache memory: {total_kv_memory:.4f} GB for {total_tokens} tokens"
        )
        print(f"Measured KV cache per token: {per_token_gb:.8f} GB")

    return total_kv_memory, per_token_gb


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    verbose = args.verbose

    # Measure baseline memory before loading model
    force_gc_and_reset()
    baseline_memory_gb = get_memory_usage()
    print(f"Baseline memory usage: {baseline_memory_gb:.4f} GB")

    print(f"Loading model from {args.model}...")

    # Clear memory before starting
    force_gc_and_reset()

    # Load the model
    model_pkg, tokenizer = load(args.model)

    # Memory after loading model but before materializing parameters
    model_load_memory_gb = get_memory_usage()

    # Materialize model parameters to get accurate parameter memory usage
    print("Materializing model parameters...")
    force_gc_and_reset()

    # Force materialization of all model parameters
    for p in model_pkg.model.parameters().values():
        mx.eval(p)

    # Get memory used by materialized parameters
    parameter_memory_gb = get_active_memory_gb()
    print(f"Model parameters memory: {parameter_memory_gb:.4f} GB")

    # Try to read prompt from file if it exists, otherwise use default
    prompt = args.prompt
    try:
        if os.path.exists(args.prompt_file):
            with open(args.prompt_file, "r") as f:
                file_prompt = f.read().strip()
            if file_prompt:
                prompt = file_prompt
                print(
                    f"Using prompt from {args.prompt_file} ({len(prompt)} characters)"
                )
            else:
                print(f"Empty prompt file, using default prompt")
    except Exception as e:
        print(f"Error reading prompt file: {e}, using default prompt")

    # Count tokens in the prompt
    tokens = tokenizer.encode(prompt)
    prompt_token_count = len(tokens)
    print(f"Prompt length: {prompt_token_count} tokens")

    # Run inference with memory tracking
    print(f"Running inference with prompt...")

    # Reset memory state
    force_gc_and_reset()

    # First do a simple forward pass to measure activation memory
    print("Running single forward pass...")

    # Tokenize input
    inputs = mx.array([tokens])

    # Create an evaluation function that we will trace
    def forward_pass():
        return model_pkg.model(inputs)

    # Trace the function to capture the actual memory during execution
    output = forward_pass()
    mx.eval(output)

    # Get memory used during forward pass
    forward_memory_gb = get_memory_usage()
    print(f"Peak memory during forward pass: {forward_memory_gb:.4f} GB")

    # Get model config for additional details
    config = fetch_config(args.model)
    bits_per_weight = compute_bits_per_weight_from_config(config)

    # Get the necessary parameters for KV cache calculation from the model config
    from mlx_lm.utils import _get_classes

    _, model_args_class = _get_classes(config)
    model_args = model_args_class.from_dict(config)

    # Extract the parameters needed for KV cache calculation
    head_dim = calculate_head_dim(config, model_args)
    num_kv_heads = getattr(
        model_args, "num_key_value_heads", model_args.num_attention_heads
    )
    num_layers = model_args.num_hidden_layers

    # Now directly measure the KV cache memory
    print("\nDirectly measuring KV cache memory usage...")
    actual_kv_cache_gb, actual_per_token_gb = measure_kv_cache_memory(
        model_pkg,
        tokenizer,
        tokens,
        num_new_tokens=min(20, args.num_tokens),
        verbose=verbose,
    )

    # Measure memory during token generation (inference) using proper generate function
    print("\nMeasuring memory during full token generation (streaming)...")
    force_gc_and_reset()
    baseline_for_generation = get_active_memory_gb()

    # Use stream_generate to get token-by-token memory measurements
    generation_text = ""
    peak_memory_gb = 0
    total_tokens_generated = 0
    token_memory_profile = []

    # Stream generation and track memory for each token
    for response in stream_generate(
        model_pkg.model, tokenizer, prompt, max_tokens=args.num_tokens
    ):
        generation_text += response.text
        total_tokens_generated += 1

        # Track memory per token
        current_memory = response.peak_memory
        peak_memory_gb = max(peak_memory_gb, current_memory)

        # Record memory for this token
        token_memory_profile.append(current_memory)

        if verbose and total_tokens_generated % 10 == 0:
            print(
                f"Generated {total_tokens_generated} tokens, current memory: {current_memory:.4f} GB"
            )

    # Calculate final memory usage
    generation_memory_gb = peak_memory_gb
    print(
        f"Peak memory during generation of {total_tokens_generated} tokens: {generation_memory_gb:.4f} GB"
    )

    # You can also add this to get more detailed memory profile analysis
    if verbose:
        print("\nMemory growth during generation:")
        for i, mem in enumerate(token_memory_profile):
            if i % 5 == 0 or i == len(token_memory_profile) - 1:
                print(f"  Token {i+1}: {mem:.4f} GB")

    # Calculate activation memory (peak memory during generation minus parameter memory and KV cache)
    actual_activation_memory_gb = (
        generation_memory_gb - parameter_memory_gb - actual_kv_cache_gb
    )
    if actual_activation_memory_gb < 0:
        # This can happen due to memory reclamation between measurements
        actual_activation_memory_gb = (
            0.01 * parameter_memory_gb
        )  # Use a reasonable fallback

    # Get estimated memory usage from estimate_memory.py (added in this commit).
    # estimate_mem takes the prompt length as context_length and the number of
    # generated tokens as tokens_to_generate.
    estimated_results, mode = estimate_mem(
        args.model,
        context_length=prompt_token_count,
        max_kv_size=args.max_kv_size,
        kv_bits=args.kv_bits,
        kv_group_size=args.kv_group_size,
        tokens_to_generate=total_tokens_generated,
    )

    # Compare estimation accuracy
    print("\n--- MEMORY ESTIMATION ACCURACY ---")
    print(f"Model: {args.model}")
    print(f"Architecture: {config.get('model_type', 'unknown')}")
    print(f"Quantization: {bits_per_weight} bits per weight")
    print(f"Total tokens processed: {prompt_token_count + total_tokens_generated}")
    print("-" * 40)
    print(f"Actual model parameters memory: {parameter_memory_gb:.3g} GB")
    print(f"Estimated model weights memory: {estimated_results['Weight']:.3g} GB")
    print(
        f"Weight memory error: {abs(parameter_memory_gb - estimated_results['Weight']):.3g} GB"
    )
    print("-" * 40)
    print(f"Actual KV cache memory: {actual_kv_cache_gb:.3g} GB")
    print(f"Estimated KV cache memory: {estimated_results['KV Cache']:.3g} GB")
    print(
        f"KV cache memory error: {abs(actual_kv_cache_gb - estimated_results['KV Cache']):.3g} GB"
    )
    print("-" * 40)
    print(f"Actual per-token KV increase: {actual_per_token_gb:.6g} GB")
    print(
        f"Estimated per-token KV increase: {estimated_results['per_token_increase']:.6g} GB"
    )
    print(
        f"Per-token KV error: {abs(actual_per_token_gb - estimated_results['per_token_increase']):.6g} GB"
    )
    print("-" * 40)
    print(f"Actual activation memory: {actual_activation_memory_gb:.3g} GB")
    print(
        f"Estimated overhead (activations + tokenizer): {estimated_results['Overhead']:.3g} GB"
    )
    print(
        f"Activation vs. overhead error: {abs(actual_activation_memory_gb - estimated_results['Overhead']):.3g} GB"
    )
    print("-" * 40)
    print(f"Total peak memory (actual): {generation_memory_gb:.3g} GB")
    print(f"Total memory (estimated): {estimated_results['Total']:.3g} GB")
    print(
        f"Total memory error: {abs(generation_memory_gb - estimated_results['Total']):.3g} GB"
    )
    print(f"Mode: {mode}")


if __name__ == "__main__":
    main()