mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 20:46:50 +08:00
294 lines
11 KiB
Python
294 lines
11 KiB
Python
![]() |
# wan_model_io.py
|
||
|
|
||
|
from typing import List, Tuple, Set, Dict
|
||
|
import os
|
||
|
import mlx.core as mx
|
||
|
from mlx.utils import tree_unflatten, tree_flatten
|
||
|
from safetensors import safe_open
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
import glob
|
||
|
|
||
|
|
||
|
def map_wan_2_2_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
    """Map one PyTorch WAN 2.2 parameter to its MLX name and layout.

    Numbered children of the ``ffn``, ``text_embedding``, ``time_embedding``
    and ``time_projection`` containers are renamed from ``<name>.N.`` to
    ``<name>.layers.N.`` (PyTorch Sequential uses ``.0``, ``.1``, ...; the MLX
    side uses ``.layers.0``, ``.layers.1``, ...). Any numeric index is
    accepted. Block indices (``blocks.N``) are left untouched.

    A key that already contains ``.layers.`` is not renamed again.

    Args:
        key: Parameter name as stored in the PyTorch checkpoint.
        value: Parameter array. Only ``patch_embedding.weight`` is changed.

    Returns:
        A single-element list holding the ``(mlx_key, mlx_value)`` pair.
    """
    import re  # function-scope import keeps this block self-contained

    # One pattern per Sequential container. Group 1 is the container name
    # (".ffn" keeps its leading dot so only nested ffn modules match) and
    # group 2 is the numeric child index.
    sequential_patterns = (
        r"(\.ffn)\.(\d+)\.",
        r"(text_embedding)\.(\d+)\.",
        r"(time_embedding)\.(\d+)\.",
        r"(time_projection)\.(\d+)\.",
    )
    for pattern in sequential_patterns:
        # Once ".layers." is present (already converted, or inserted by an
        # earlier pattern) no further renaming is applied.
        if ".layers." in key:
            break
        key = re.sub(pattern, r"\1.layers.\2.", key)

    if "patch_embedding.weight" in key:
        # PyTorch Conv3d: (out_channels, in_channels, D, H, W)
        # MLX Conv3d:     (out_channels, D, H, W, in_channels)
        value = mx.transpose(value, (0, 2, 3, 4, 1))

    return [(key, value)]
|
||
|
|
||
|
|
||
|
def check_parameter_mismatch(model, weights: Dict[str, mx.array]) -> Tuple[Set[str], Set[str]]:
    """
    Compare the parameter names of a model against a weights dictionary.

    Args:
        model: Object whose ``parameters()`` result is flattened with
            ``tree_flatten`` to obtain its parameter names.
        weights: Mapping from parameter name to array.

    Returns:
        (model_only, weights_only): names present only in the model and names
        present only in ``weights``.
    """
    # Buffers the model computes itself and that are never stored in a
    # checkpoint; extend this set if more are introduced.
    computed_buffers = {'freqs'}

    expected = {name for name, _ in tree_flatten(model.parameters())}
    expected -= computed_buffers
    provided = set(weights)

    return expected - provided, provided - expected
|
||
|
|
||
|
|
||
|
def _report_load_mismatch(model_only: Set[str], weights_only: Set[str]) -> None:
    """Print at most ten names from each side of a model/weights mismatch."""
    reports = (
        (model_only, "parameters in model but NOT in weights"),
        (weights_only, "parameters in weights but NOT in model"),
    )
    for names, description in reports:
        if not names:
            continue
        print(f"\n⚠️ WARNING: {len(names)} {description}:")
        for param in sorted(names)[:10]:  # Show first 10
            print(f" - {param}")
        if len(names) > 10:
            print(f" ... and {len(names) - 10} more")

    if not model_only and not weights_only:
        print("\n✅ Perfect match: All parameters align between model and weights!")


def load_wan_2_2_from_safetensors(
    safetensors_path: str,
    model,
    float16: bool = False,
    check_mismatch: bool = True
):
    """
    Load WAN 2.2 Model weights from safetensors file(s) into MLX model.

    When ``safetensors_path`` is a directory, every file matching
    ``diffusion_mlx_model*.safetensors`` inside it is loaded in sorted order
    and merged; on duplicate keys the later file wins. Otherwise the path is
    loaded as a single file.

    Args:
        safetensors_path: Path to safetensors file or directory
        model: MLX model instance; updated in place
        float16: Accepted for interface compatibility. This function does not
            read it and applies no cast to the loaded arrays.
        check_mismatch: Whether to check for parameter mismatches and print a
            report before updating the model

    Returns:
        The same ``model`` object, after ``model.update`` has been called.

    Raises:
        FileNotFoundError: If ``safetensors_path`` is a directory that holds
            no ``diffusion_mlx_model*.safetensors`` file.
    """
    if os.path.isdir(safetensors_path):
        # Multiple files (14B model) - only diffusion_mlx_model files
        pattern = os.path.join(safetensors_path, "diffusion_mlx_model*.safetensors")
        safetensor_files = sorted(glob.glob(pattern))
        if not safetensor_files:
            # Fail loudly instead of reporting success with nothing loaded.
            raise FileNotFoundError(f"No MLX model files found matching: {pattern}")
        print(f"Found {len(safetensor_files)} diffusion_mlx_model safetensors files")

        # Load all files and merge weights
        weights = {}
        for file_path in safetensor_files:
            print(f"Loading: {file_path}")
            weights.update(mx.load(file_path))
    else:
        # Single file
        print(f"Loading single file: {safetensors_path}")
        weights = mx.load(safetensors_path)

    if check_mismatch:
        model_only, weights_only = check_parameter_mismatch(model, weights)
        _report_load_mismatch(model_only, weights_only)

    model.update(tree_unflatten(list(weights.items())))

    print("\nWAN 2.2 Model weights loaded successfully!")
    return model
|
||
|
|
||
|
|
||
|
def convert_wan_2_2_safetensors_to_mlx(
    safetensors_path: str,
    output_path: str,
    float16: bool = False,
    model=None  # Optional: provide model instance to check parameter alignment
):
    """
    Convert one WAN 2.2 PyTorch safetensors file into an MLX weights file.

    Each tensor is read on the CPU, cast to the target dtype, passed through
    ``map_wan_2_2_weights`` and collected; the result is written with
    ``mx.save_safetensors``.

    Args:
        safetensors_path: Input safetensors file
        output_path: Output MLX weights file (.safetensors)
        float16: Store weights as float16 when True, float32 otherwise
        model: Optional MLX model instance; when given, a parameter-name
            comparison is printed before saving

    Returns:
        Dict mapping converted parameter names to MLX arrays.
    """
    target_dtype = mx.float16 if float16 else mx.float32

    print("Converting WAN 2.2 safetensors to MLX format...")
    print(f"Input: {safetensors_path}")
    print(f"Output: {output_path}")
    print(f"Target dtype: {target_dtype}")

    converted = {}
    num_bfloat16 = 0

    with safe_open(safetensors_path, framework="pt", device="cpu") as reader:
        names = list(reader.keys())
        print(f"Processing {len(names)} parameters...")

        for name in names:
            tensor = reader.get_tensor(name)

            # bfloat16 tensors are widened to float32 before the numpy hop.
            if tensor.dtype == torch.bfloat16:
                num_bfloat16 += 1
                tensor = tensor.float()

            array = mx.array(tensor.numpy()).astype(target_dtype)
            converted.update(map_wan_2_2_weights(name, array))

    if num_bfloat16 > 0:
        print(f"⚠️ Converted {num_bfloat16} BFloat16 tensors to {target_dtype}")

    # Optional alignment report against a live model.
    if model is not None:
        print("\nChecking parameter alignment with model...")
        model_only, weights_only = check_parameter_mismatch(model, converted)

        sides = (
            (model_only, "parameters in model but NOT in converted weights"),
            (weights_only, "parameters in converted weights but NOT in model"),
        )
        for missing, description in sides:
            if not missing:
                continue
            print(f"\n⚠️ WARNING: {len(missing)} {description}:")
            for param in sorted(missing)[:10]:
                print(f" - {param}")
            if len(missing) > 10:
                print(f" ... and {len(missing) - 10} more")

        if not (model_only or weights_only):
            print("\n✅ Perfect match: All parameters align between model and converted weights!")

    print(f"\nSaving {len(converted)} parameters to: {output_path}")
    mx.save_safetensors(output_path, converted)

    # A short sample of the result for eyeball verification.
    print("\nExample converted keys:")
    for name in sorted(converted)[:10]:
        print(f" {name}: {converted[name].shape}")

    return converted
|
||
|
|
||
|
|
||
|
def convert_multiple_wan_2_2_safetensors_to_mlx(
    checkpoint_dir: str,
    float16: bool = False
):
    """Convert every sharded WAN 2.2 PyTorch safetensors file in a directory.

    Each ``diffusion_pytorch_model-<suffix>.safetensors`` file in
    ``checkpoint_dir`` is converted to
    ``diffusion_mlx_model-<suffix>.safetensors`` in the same directory.

    Args:
        checkpoint_dir: Directory holding the PyTorch shard files.
        float16: Forwarded to ``convert_wan_2_2_safetensors_to_mlx``.

    Raises:
        FileNotFoundError: If no shard file matches the expected pattern.
    """
    source_glob = os.path.join(checkpoint_dir, "diffusion_pytorch_model-*.safetensors")
    sources = sorted(glob.glob(source_glob))

    if not sources:
        raise FileNotFoundError(f"No PyTorch model files found matching: {source_glob}")

    total = len(sources)
    print(f"Converting {total} PyTorch files to MLX format...")

    for position, source in enumerate(sources, start=1):
        source_name = os.path.basename(source)

        # Shard suffix, e.g. "00001-of-00006".
        shard = source_name.replace("diffusion_pytorch_model-", "")
        shard = shard.replace(".safetensors", "")

        target = os.path.join(checkpoint_dir, f"diffusion_mlx_model-{shard}.safetensors")

        print(f"\nConverting {position}/{total}: {source_name}")
        convert_wan_2_2_safetensors_to_mlx(source, target, float16)

    print("\nAll files converted successfully!")
|
||
|
|
||
|
|
||
|
def debug_wan_2_2_weight_mapping(safetensors_path: str, float16: bool = False):
    """
    Print how the first 30 keys of a safetensors file are renamed.

    Each inspected tensor is converted exactly as in the real conversion path
    (bfloat16 widened to float32, then cast to the target dtype) and passed
    through ``map_wan_2_2_weights``. One line is printed per key, tagged
    ``UNCHANGED`` or ``MAPPED``, with the tensor's shape.

    Args:
        safetensors_path: PyTorch safetensors file to inspect.
        float16: Use float16 as the target dtype when True, float32 otherwise.
    """
    target_dtype = mx.float16 if float16 else mx.float32

    print("=== WAN 2.2 Weight Mapping Debug ===")

    with safe_open(safetensors_path, framework="pt", device="cpu") as reader:
        # Only the first 30 keys are inspected.
        for name in list(reader.keys())[:30]:
            tensor = reader.get_tensor(name)

            if tensor.dtype == torch.bfloat16:
                tensor = tensor.float()

            array = mx.array(tensor.numpy()).astype(target_dtype)

            # Only the first mapped pair is reported.
            mapped_name, _ = map_wan_2_2_weights(name, array)[0]

            if mapped_name != name:
                print(f"MAPPED: {name} -> {mapped_name} [{tensor.shape}]")
            else:
                print(f"UNCHANGED: {name} [{tensor.shape}]")
|