# mlx-examples/llms/mlx_lm/estimate_memory.py

# Copyright © 2023-2025 Apple Inc.
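"""Estimate how much memory an MLX LM model needs, using only its config.

The estimate covers the model weights, the KV cache (optionally quantized and/or
bounded), and a rough allowance for tokenizer and activation overhead.

Example invocation (the model argument is a placeholder, not a specific repo;
run from an environment where mlx_lm is installed):

    python estimate_memory.py <local-path-or-hf-repo-id> --context-length 8192 --tokens-to-generate 1024
"""
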
import argparse
import json
import math
import os
from pathlib import Path
from typing import Dict, Optional, Tuple, Type

from huggingface_hub import hf_hub_download, try_to_load_from_cache

from mlx_lm.models.base import BaseModelArgs
from mlx_lm.utils import _get_classes


def fetch_metadata(model_path: str) -> Tuple[Dict, Optional[int]]:
    """Fetch config.json and optionally model.safetensors.index.json for the weights size."""
    config = fetch_config(model_path)
    model_weight_size = None
    if not os.path.isdir(model_path):
        try:
            index_path = hf_hub_download(
                repo_id=model_path, filename="model.safetensors.index.json"
            )
            with open(index_path, "r") as f:
                index = json.load(f)
            model_weight_size = index.get("metadata", {}).get("total_size")
        except Exception:
            # The index file is optional; fall back to estimating from the parameter count.
            pass
    return config, model_weight_size


def fetch_config(model_path: str) -> Dict:
    """Fetch or load config.json without downloading the full model, checking the cache first."""
    if os.path.isdir(model_path):
        config_path = Path(model_path) / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_path}")
        with open(config_path, "r") as f:
            return json.load(f)
    else:
        cached_path = try_to_load_from_cache(
            repo_id=model_path, filename="config.json", repo_type="model"
        )
        # try_to_load_from_cache may return None or a non-path sentinel, so only
        # accept an actual string path that exists on disk.
        if isinstance(cached_path, str) and os.path.exists(cached_path):
            with open(cached_path, "r") as f:
                return json.load(f)
        try:
            config_path = hf_hub_download(repo_id=model_path, filename="config.json")
            with open(config_path, "r") as f:
                return json.load(f)
        except Exception as e:
            raise ValueError(f"Failed to fetch config.json from {model_path}: {e}")


def compute_bits_per_weight_from_config(config: Dict) -> float:
    """Infer bits-per-weight from config, defaulting to 16 (FP16) if unquantized."""
    quantization = config.get("quantization", {})
    bits = quantization.get("bits")
    return float(bits) if bits is not None else 16.0


def calc_embedding_params(vocab_size: int, hidden_size: int) -> int:
    """Calculate parameters for the embedding layer."""
    return vocab_size * hidden_size


def calc_attention_params(args, hidden_size: int, num_attention_heads: int) -> int:
    """Calculate parameters for one attention layer, handling standard and LoRA variants.

    This function supports both standard multi-head attention (e.g., Mixtral, OLMoE) and
    LoRA-like attention (e.g., DeepSeek V3) by checking for q_lora_rank, allowing more
    flexibility than the older hardcoded approach that assumed uniform QKV dimensions.
    """
    num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
    head_dim = getattr(args, "head_dim", None)
    if head_dim is None:
        head_dim = hidden_size // num_attention_heads
    has_bias = getattr(args, "attention_bias", False)

    # Standard attention (Q, K, V, O)
    if not hasattr(args, "q_lora_rank") or not args.q_lora_rank:
        q_params = hidden_size * (num_attention_heads * head_dim)
        k_params = hidden_size * (num_kv_heads * head_dim)
        v_params = k_params
        o_params = (num_attention_heads * head_dim) * hidden_size
    # LoRA-like attention (e.g., DeepSeek V3)
    else:
        q_head_dim = getattr(args, "qk_nope_head_dim", 0) + getattr(
            args, "qk_rope_head_dim", head_dim
        )
        v_head_dim = getattr(args, "v_head_dim", head_dim)
        qk_rope_dim = getattr(args, "qk_rope_head_dim", head_dim)
        q_params = hidden_size * args.q_lora_rank + args.q_lora_rank * (
            num_attention_heads * q_head_dim
        )
        k_params = hidden_size * (args.kv_lora_rank + qk_rope_dim)
        v_params = args.kv_lora_rank * (
            num_attention_heads * (q_head_dim - qk_rope_dim + v_head_dim)
        )
        o_params = (num_attention_heads * v_head_dim) * hidden_size

    total = q_params + k_params + v_params + o_params
    if has_bias:
        total += (
            num_attention_heads * head_dim + num_kv_heads * head_dim * 2 + hidden_size
        )
    return total


def calc_ffn_or_moe_params(
    args, hidden_size: int, intermediate_size: int, layer_idx: int
) -> int:
    """Calculate parameters for an FFN or MoE layer, switching based on the config.

    Unlike the previous hardcoded FFN-only calculation, this function dynamically handles
    Mixture of Experts (MoE) models like Mixtral, OLMoE, and DeepSeek V3 by detecting
    expert-related fields (num_experts, n_routed_experts) and adjusting for dense vs.
    MoE layers, supporting varied intermediate sizes and shared experts.
    """
    num_experts = max(
        getattr(args, "num_local_experts", 0),
        getattr(args, "num_experts", 0),
        getattr(args, "n_routed_experts", 0),
    )
    moe_intermediate_size = getattr(args, "moe_intermediate_size", intermediate_size)
    dense_up_to = (
        getattr(args, "first_k_dense_replace", 0)
        if num_experts
        else args.num_hidden_layers
    )
    has_bias = getattr(args, "mlp_bias", False)
    shared_experts = getattr(args, "n_shared_experts", 0)

    if num_experts and layer_idx >= dense_up_to:
        # MoE: router gate + expert FFNs (gate_proj, up_proj, down_proj per expert)
        gate_params = hidden_size * num_experts
        expert_params = num_experts * hidden_size * moe_intermediate_size * 3
        shared_params = (
            shared_experts * hidden_size * moe_intermediate_size * 3
            if shared_experts
            else 0
        )
        return gate_params + expert_params + shared_params
    else:
        # Dense FFN: gate_proj + up_proj (hidden -> intermediate) and down_proj (intermediate -> hidden)
        ffn_params = (
            hidden_size * intermediate_size * 2 + intermediate_size * hidden_size
        )
        if has_bias:
            ffn_params += intermediate_size * 2 + hidden_size
        return ffn_params


def calc_norm_params(args, hidden_size: int, num_attention_heads: int) -> int:
    """Calculate normalization parameters, adjusting for extra norms in complex models.

    This extends the old approach (a fixed 2 norms per layer) by adding heuristic support
    for extra normalization layers (e.g., OLMoE's q_norm and k_norm) in MoE or LoRA models,
    improving accuracy over the simpler assumption of uniform RMSNorm usage.
    """
    num_kv_heads = getattr(args, "num_key_value_heads", num_attention_heads)
    head_dim = getattr(args, "head_dim", None)
    if head_dim is None:
        head_dim = hidden_size // num_attention_heads

    # Base: input + post-attention RMSNorm
    total = hidden_size * 2

    # Heuristic: extra norms for MoE or complex attention
    has_experts = any(
        getattr(args, attr, 0) > 0
        for attr in ["num_local_experts", "num_experts", "n_routed_experts"]
    )
    if has_experts or hasattr(args, "q_lora_rank"):
        total += (num_attention_heads * head_dim) + (num_kv_heads * head_dim)
    return total


def calculate_num_parameters(
    config: Dict, model_args_class: Optional[Type["BaseModelArgs"]] = None
) -> int:
    """Calculate the total number of parameters in a model based on its config.

    By splitting the work into separate helper functions, we support diverse
    architectures while maintaining readability and avoiding model-specific hardcoding.
    """
    # Use the imported _get_classes function to get the ModelArgs class
    if model_args_class is None:
        _, model_args_class = _get_classes(config)
    args = model_args_class.from_dict(config)

    # Validate required fields
    required = [
        "hidden_size",
        "num_hidden_layers",
        "vocab_size",
        "num_attention_heads",
        "intermediate_size",
    ]
    missing = [field for field in required if getattr(args, field, None) is None]
    if missing:
        raise ValueError(f"Config missing required fields: {missing}")

    # Core config
    hidden_size = args.hidden_size
    num_layers = args.num_hidden_layers
    vocab_size = args.vocab_size
    num_attention_heads = args.num_attention_heads
    intermediate_size = args.intermediate_size

    # Total calculation
    total_params = calc_embedding_params(vocab_size, hidden_size)
    for layer in range(num_layers):
        total_params += calc_attention_params(args, hidden_size, num_attention_heads)
        total_params += calc_ffn_or_moe_params(
            args, hidden_size, intermediate_size, layer
        )
        total_params += calc_norm_params(args, hidden_size, num_attention_heads)
    total_params += hidden_size  # Final norm
    if not getattr(args, "tie_word_embeddings", True):
        total_params += hidden_size * vocab_size  # LM head
    return total_params


def calculate_head_dim(config: Dict, args: BaseModelArgs) -> int:
    """Infer the head dimension dynamically from the config or args."""
    head_dim = getattr(args, "head_dim", None)
    if head_dim is None:
        if "hidden_size" not in config or "num_attention_heads" not in config:
            raise ValueError(
                "Cannot compute head_dim: missing hidden_size or num_attention_heads"
            )
        head_dim = config["hidden_size"] // config["num_attention_heads"]
    return head_dim


def estimate_mem(
    model_path: str,
    context_length: int = 4096,
    max_kv_size: Optional[int] = None,
    kv_bits: Optional[int] = None,
    kv_group_size: Optional[int] = None,
    tokens_to_generate: int = 0,
) -> Tuple[Dict[str, float], str]:
    """
    Estimate the memory usage of a model.

    Args:
        model_path: Path to the model.
        context_length: Context length of the model (prompt length in unbounded mode).
        max_kv_size: Maximum size of the KV cache (for bounded mode).
        kv_bits: Number of bits to use for the quantized KV cache.
        kv_group_size: Group size to use for the quantized KV cache.
        tokens_to_generate: Number of tokens to generate beyond the prompt.

    Returns:
        A tuple of (results, mode), where results is a dictionary of memory usage
        and mode is a string indicating the mode of the KV cache.
    """
    config, model_weight_size = fetch_metadata(model_path)
    bits_per_weight = compute_bits_per_weight_from_config(config)

    # Determine the model class
    _, model_args_class = _get_classes(config)
    args = model_args_class.from_dict(config)

    # Calculate the number of parameters
    num_parameters = calculate_num_parameters(config, model_args_class)

    # Extract the model architecture parameters needed for the memory calculations
    num_layers = args.num_hidden_layers
    num_kv_heads = getattr(args, "num_key_value_heads", args.num_attention_heads)
    head_dim = calculate_head_dim(config, args)

    # Default to fp16 (2 bytes per element) for the KV cache unless quantized
    bytes_per_element = 2

    # If kv_bits and kv_group_size are not provided, try to read them from the config
    if kv_bits is None or kv_group_size is None:
        # Try to get quantization settings from the config
        quantization = config.get("quantization", {})
        quantization_config = config.get("quantization_config", {})
        # Use the first available quantization config
        quant_info = quantization or quantization_config
        if quant_info:
            kv_bits = kv_bits or quant_info.get("bits")
            kv_group_size = kv_group_size or quant_info.get("group_size")

    # Calculate the model weight memory usage
    bytes_per_parameter = bits_per_weight / 8
    if model_weight_size:
        # Use the size from the safetensors index if available
        model_size_gb = model_weight_size / (1024**3)
    else:
        # Calculate from the parameter count
        model_size_gb = (num_parameters * bytes_per_parameter) / (1024**3)

    # Estimate the tokenizer size
    vocab_size = config.get("vocab_size", args.vocab_size)
    fixed_overhead_bytes = 25 * 1024 * 1024
    avg_token_size_bytes = 650
    tokenizer_size_bytes = (vocab_size * avg_token_size_bytes) + fixed_overhead_bytes
    tokenizer_size_gb = tokenizer_size_bytes / (1024**3)
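    # Illustrative example: vocab_size=32000 gives roughly 32000 * 650 bytes + 25 MiB,
    # i.e. about 45 MiB (~0.04 GB) for the tokenizer estimate.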

    # Determine the mode
    mode = "Bounded" if max_kv_size else "Unbounded"
    # The KV length is fixed to max_kv_size in bounded mode, or context_length in unbounded mode
    kv_length = max_kv_size if mode == "Bounded" else context_length
    # Default step size from cache.py is 256
    step_size = 256
    kv_length_padded = ((kv_length + step_size - 1) // step_size) * step_size
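    # Illustrative example: kv_length=4096 is already a multiple of 256, while
    # kv_length=5000 would be padded up to 5120 before sizing the cache.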

    # Calculate the KV cache size based on whether quantization is used
    if kv_bits and kv_group_size:
        # Quantized cache calculations
        groups_per_head_dim = (head_dim + kv_group_size - 1) // kv_group_size
        elements_per_int = 8 * 4 // kv_bits
        data_size = (
            num_kv_heads * kv_length_padded * (head_dim // elements_per_int) * 4
        ) / (1024**3)
        quant_overhead = (
            num_kv_heads * kv_length_padded * groups_per_head_dim * 2 * 2
        ) / (1024**3)
        per_layer_kv_size = 2 * (data_size + quant_overhead)

        elements_per_token = (head_dim // elements_per_int) * 4
        scales_zeros_per_token = groups_per_head_dim * 2 * 2
        per_token_bytes = (
            2 * num_kv_heads * (elements_per_token + scales_zeros_per_token)
        )
        per_token_increase = (per_token_bytes * num_layers) / (1024**3)
    else:
        # Standard fp16 cache
        per_layer_kv_size = (
            2 * num_kv_heads * kv_length_padded * head_dim * bytes_per_element
        ) / (1024**3)
        per_token_increase = (
            2 * num_kv_heads * head_dim * bytes_per_element * num_layers
        ) / (1024**3)
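
    # Illustrative fp16 example: num_kv_heads=8, head_dim=128, kv_length_padded=4096
    # gives 2 * 8 * 4096 * 128 * 2 bytes = 16 MiB per layer, or 512 MiB across 32 layers.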
    total_kv_cache_size = num_layers * per_layer_kv_size

    # Add the memory for generated tokens if specified
    if tokens_to_generate > 0:
        total_kv_cache_size += tokens_to_generate * per_token_increase

    # For inference in MLX, estimate the activation memory
    activation_size_gb = 0.03 * model_size_gb
    overhead_gb = tokenizer_size_gb + activation_size_gb + (model_size_gb * 0.05)

    # Total memory usage
    total_memory_gb = model_size_gb + total_kv_cache_size + overhead_gb

    results = {
        "Weight": model_size_gb,
        "KV Cache": total_kv_cache_size,
        "Overhead": overhead_gb,
        "Total": total_memory_gb,
        "per_token_increase": per_token_increase,
    }
    return results, mode


def setup_arg_parser():
    parser = argparse.ArgumentParser(
        description="MLX Model Memory (Unified RAM) Estimation Tool"
    )
    parser.add_argument("model", help="Local model path or Hugging Face repo ID.")
    parser.add_argument(
        "--context-length",
        type=int,
        default=4096,
        help="Context length of the model (prompt length in unbounded mode).",
    )
    parser.add_argument(
        "--max-kv-size", type=int, help="Max KV cache size (enables bounded mode)."
    )
    parser.add_argument("--kv-bits", type=int, help="Bits for KV cache quantization.")
    parser.add_argument(
        "--kv-group-size", type=int, help="Group size for KV cache quantization."
    )
    parser.add_argument(
        "--tokens-to-generate",
        type=int,
        default=0,
        help="Number of tokens to generate beyond the prompt.",
    )
    return parser


def print_table(
    results: Dict[str, float], mode: str, tokens_to_generate: int = 0
) -> None:
    """
    Print a memory usage table in a formatted way.

    Args:
        results: Dictionary containing memory usage data (in GB unless specified).
        mode: Either "Bounded" or "Unbounded" to describe the KV cache type.
        tokens_to_generate: Number of tokens generated (optional, defaults to 0).
    """
    # Construct the title dynamically
    title = f"*Memory Usage Estimate ({mode} KV Cache"
    if tokens_to_generate > 0:
        title += f" after generating {tokens_to_generate:,} tokens"
    title += "):*"

    # Define table formatting constants
    LINE_WIDTH = 34
    ITEM_COL_WIDTH = 17
    MEMORY_COL_WIDTH = 12

    # Print the header
    print(title)
    print("-" * LINE_WIDTH)
    print(f"| {'Item':<{ITEM_COL_WIDTH}} | {'Memory':<{MEMORY_COL_WIDTH}} |")
    print("-" * LINE_WIDTH)

    # Define the display order and handle missing keys gracefully
    display_order = ["Weight", "KV Cache", "Overhead", "Total"]
    for key in display_order:
        value = results.get(key, 0.0)  # Default to 0.0 if the key is missing
        memory_str = f"{value:.2f} GB"
        # Print the row (memory column is right-aligned)
        if key == "Total":
            print("-" * LINE_WIDTH)
        print(f"| {key:<{ITEM_COL_WIDTH}} | {memory_str:>{MEMORY_COL_WIDTH}} |")
    print("-" * LINE_WIDTH)

    # Add a footer for unbounded mode
    if mode == "Unbounded" and "per_token_increase" in results:
        per_token_gb = results["per_token_increase"]
        if per_token_gb > 0:  # Avoid division by zero
            tokens_per_gb = math.floor(1 / per_token_gb)
            print(f"Additional tokens per 1GB increase: {tokens_per_gb:,}")
        else:
            print("Note: Per-token increase is zero or invalid.")


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    results, mode = estimate_mem(
        args.model,
        args.context_length,
        args.max_kv_size,
        args.kv_bits,
        args.kv_group_size,
        args.tokens_to_generate,
    )
    print_table(results, mode, args.tokens_to_generate)


if __name__ == "__main__":
    main()