"""Convert Wan2.1 T5 encoder weights from a PyTorch pickle (.pth) checkpoint to safetensors format."""

import os
import json
from pathlib import Path

import torch
from safetensors.torch import save_file

from wan.modules.t5 import T5Model


def convert_pickle_to_safetensors(
    pickle_path: str,
    safetensors_path: str,
    model_class=None,
    model_kwargs=None,
    load_method: str = "weights_only"  # Changed default to weights_only
):
    """Convert a PyTorch pickle file to safetensors format."""
    print(f"Loading PyTorch weights from: {pickle_path}")

    # Try multiple loading methods in order of preference
    methods_to_try = [load_method, "weights_only", "state_dict", "full_model"]

    for method in methods_to_try:
        try:
            if method == "weights_only":
                state_dict = torch.load(pickle_path, map_location='cpu', weights_only=True)
            elif method == "state_dict":
                checkpoint = torch.load(pickle_path, map_location='cpu')
                if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                elif isinstance(checkpoint, dict) and 'model' in checkpoint:
                    state_dict = checkpoint['model']
                else:
                    state_dict = checkpoint
            elif method == "full_model":
                model = torch.load(pickle_path, map_location='cpu')
                if hasattr(model, 'state_dict'):
                    state_dict = model.state_dict()
                else:
                    state_dict = model

            print(f"āœ… Successfully loaded with method: {method}")
            break
        except Exception as e:
            print(f"āŒ Method {method} failed: {e}")
            continue
    else:
        raise RuntimeError(f"All loading methods failed for {pickle_path}")

    # Clean up the state dict
    state_dict = clean_state_dict(state_dict)
    print(f"Found {len(state_dict)} parameters")

    # Convert BF16 to FP32 if needed
    for key, tensor in state_dict.items():
        if tensor.dtype == torch.bfloat16:
            state_dict[key] = tensor.to(torch.float32)
            print(f"Converted {key} from bfloat16 to float32")

    # Save as safetensors
    print(f"Saving to safetensors: {safetensors_path}")
    output_dir = os.path.dirname(safetensors_path)
    if output_dir:  # os.makedirs('') would raise when the path has no directory component
        os.makedirs(output_dir, exist_ok=True)
    save_file(state_dict, safetensors_path)

    print("āœ… T5 conversion complete!")
    return state_dict


def clean_state_dict(state_dict):
    """
    Clean up a state dict by removing unwanted prefixes or keys.
    """
    cleaned = {}
    for key, value in state_dict.items():
        # Remove common prefixes that might interfere
        clean_key = key
        if clean_key.startswith('module.'):
            clean_key = clean_key[7:]
        if clean_key.startswith('model.'):
            clean_key = clean_key[6:]

        cleaned[clean_key] = value
        if clean_key != key:
            print(f"Cleaned key: {key} -> {clean_key}")

    return cleaned


def load_with_your_torch_model(pickle_path: str, model_class, **model_kwargs):
    """
    Load pickle weights into your specific PyTorch T5 model implementation.

    Args:
        pickle_path: Path to pickle file
        model_class: Your T5Encoder class
        **model_kwargs: Arguments for your model constructor
    """
    print("Method 1: Loading into your PyTorch T5 model")

    # Initialize your model
    model = model_class(**model_kwargs)

    # Load checkpoint
    checkpoint = torch.load(pickle_path, map_location='cpu')

    # Handle different checkpoint formats
    if isinstance(checkpoint, dict):
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            # Assume the dict IS the state dict
            state_dict = checkpoint
    else:
        # Assume it's a model object
        state_dict = checkpoint.state_dict()

    # Clean and load
    state_dict = clean_state_dict(state_dict)
    model.load_state_dict(state_dict, strict=False)  # Use strict=False to ignore missing keys

    return model, state_dict


def explore_pickle_file(pickle_path: str):
    """
    Explore the contents of a pickle file to understand its structure.
    """
    print(f"šŸ” Exploring pickle file: {pickle_path}")

    try:
        # Try loading with weights_only first (safer)
        print("\n--- Trying weights_only=True ---")
        try:
            data = torch.load(pickle_path, map_location='cpu', weights_only=True)
            print("āœ… Loaded with weights_only=True")
            print(f"Type: {type(data)}")

            if isinstance(data, dict):
                print(f"Dictionary with {len(data)} keys:")
                for i, key in enumerate(data.keys()):
                    print(f"  {key}: {type(data[key])}")
                    if hasattr(data[key], 'shape'):
                        print(f"    Shape: {data[key].shape}")
                    if i >= 9:  # Show first 10 keys
                        break
        except Exception as e:
            print(f"āŒ weights_only=True failed: {e}")

        # Try regular loading
        print("\n--- Trying regular loading ---")
        data = torch.load(pickle_path, map_location='cpu')
        print("āœ… Loaded successfully")
        print(f"Type: {type(data)}")

        if hasattr(data, 'state_dict'):
            print("šŸ“‹ Found state_dict method")
            state_dict = data.state_dict()
            print(f"State dict has {len(state_dict)} parameters")
        elif isinstance(data, dict):
            print(f"šŸ“‹ Dictionary with keys: {list(data.keys())}")

            # Check for common checkpoint keys
            if 'state_dict' in data:
                print("Found 'state_dict' key")
                print(f"state_dict has {len(data['state_dict'])} parameters")
            elif 'model' in data:
                print("Found 'model' key")
                print(f"model has {len(data['model'])} parameters")

    except Exception as e:
        print(f"āŒ Failed to load: {e}")


def full_conversion_pipeline(
    pickle_path: str,
    safetensors_path: str,
    torch_model_class=None,
    model_kwargs=None
):
    """
    Complete pipeline: pickle -> safetensors -> ready for MLX conversion
    """
    print("šŸš€ Starting full conversion pipeline")
    print("=" * 50)

    # Step 1: Explore the pickle file
    print("Step 1: Exploring pickle file structure")
    explore_pickle_file(pickle_path)

    # Step 2: Convert to safetensors
    print("\nStep 2: Converting to safetensors")

    # Try different loading methods
    for method in ["weights_only", "state_dict", "full_model"]:
        try:
            print(f"\nTrying load method: {method}")
            state_dict = convert_pickle_to_safetensors(
                pickle_path,
                safetensors_path,
                model_class=torch_model_class,
                model_kwargs=model_kwargs,
                load_method=method
            )
            print(f"āœ… Success with method: {method}")
            break
        except Exception as e:
            print(f"āŒ Method {method} failed: {e}")
            continue
    else:
        print("āŒ All methods failed!")
        return None

    print("\nšŸŽ‰ Conversion complete!")
    print(f"Safetensors file saved to: {safetensors_path}")
    print("Ready for MLX conversion using the previous script!")
    return state_dict


# Example usage
def example_usage():
    """
    Example of how to use the conversion functions
    """
    # Your model class and parameters
    # class YourT5Encoder(nn.Module):
    #     def __init__(self, vocab_size, d_model, ...):
    #         ...

    pickle_file = "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"
    safetensors_file = "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.safetensors"

    # Method 1: Quick exploration
    print("=== Exploring pickle file ===")
    explore_pickle_file(pickle_file)

    # Method 2: Full pipeline
    print("\n=== Full conversion pipeline ===")
    state_dict = full_conversion_pipeline(
        pickle_file,
        safetensors_file,
        torch_model_class=T5Model,  # Your model class
        model_kwargs={
            'vocab_size': 256384,
            'd_model': 4096,
            'num_layers': 24,
            # ... other parameters
        }
    )

    # Method 3: Direct conversion (if you know the format)
    print("\n=== Direct conversion ===")
    # state_dict = convert_pickle_to_safetensors(
    #     pickle_file,
    #     safetensors_file,
    #     load_method="state_dict"  # or "weights_only" or "full_model"
    # )


if __name__ == "__main__":
    example_usage()