mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Implement Wan2.2
This commit is contained in:
401
video/Wan2.2/wan/text2video.py
Normal file
401
video/Wan2.2/wan/text2video.py
Normal file
@@ -0,0 +1,401 @@
|
||||
# MLX implementation of text2video.py
|
||||
|
||||
import gc
|
||||
import glob
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, List, Dict, Any, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from .modules.model import WanModel
|
||||
from .modules.t5 import T5EncoderModel
|
||||
from .modules.vae2_1 import Wan2_1_VAE
|
||||
from .utils.fm_solvers import (
|
||||
FlowDPMSolverMultistepScheduler,
|
||||
get_sampling_sigmas,
|
||||
retrieve_timesteps,
|
||||
)
|
||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
|
||||
from .wan_model_io import convert_wan_2_2_safetensors_to_mlx, convert_multiple_wan_2_2_safetensors_to_mlx, load_wan_2_2_from_safetensors
|
||||
|
||||
class WanT2V:
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir: str,
|
||||
device_id: int = 0,
|
||||
convert_model_dtype: bool = False,
|
||||
):
|
||||
r"""
|
||||
Initializes the Wan text-to-video generation model components for MLX.
|
||||
|
||||
Args:
|
||||
config (EasyDict):
|
||||
Object containing model parameters initialized from config.py
|
||||
checkpoint_dir (`str`):
|
||||
Path to directory containing model checkpoints
|
||||
device_id (`int`, *optional*, defaults to 0):
|
||||
Device id (kept for compatibility, MLX handles device automatically)
|
||||
convert_model_dtype (`bool`, *optional*, defaults to False):
|
||||
Convert DiT model parameters dtype to 'config.param_dtype'.
|
||||
"""
|
||||
self.config = config
|
||||
self.vae_stride = config.vae_stride
|
||||
self.patch_size = config.patch_size
|
||||
self.num_train_timesteps = config.num_train_timesteps
|
||||
self.boundary = config.boundary
|
||||
# Convert PyTorch dtype to MLX dtype
|
||||
if str(config.param_dtype) == 'torch.bfloat16':
|
||||
self.param_dtype = mx.bfloat16
|
||||
elif str(config.param_dtype) == 'torch.float16':
|
||||
self.param_dtype = mx.float16
|
||||
elif str(config.param_dtype) == 'torch.float32':
|
||||
self.param_dtype = mx.float32
|
||||
else:
|
||||
self.param_dtype = mx.float32 # default
|
||||
|
||||
# Initialize T5 text encoder
|
||||
print(f"checkpoint_dir is: {checkpoint_dir}")
|
||||
t5_checkpoint_path = os.path.join(checkpoint_dir, config.t5_checkpoint)
|
||||
mlx_t5_path = t5_checkpoint_path.replace('.safetensors', '_mlx.safetensors')
|
||||
if not os.path.exists(mlx_t5_path):
|
||||
# Check if it's a .pth file that needs conversion
|
||||
pth_path = t5_checkpoint_path.replace('.safetensors', '.pth')
|
||||
if os.path.exists(pth_path):
|
||||
logging.info(f"Converting T5 PyTorch model to safetensors: {pth_path}")
|
||||
from .t5_torch_to_sf import convert_pickle_to_safetensors
|
||||
convert_pickle_to_safetensors(pth_path, t5_checkpoint_path, load_method="weights_only")
|
||||
# Convert torch safetensors to MLX safetensors
|
||||
from .t5_model_io import convert_safetensors_to_mlx_weights
|
||||
convert_safetensors_to_mlx_weights(t5_checkpoint_path, mlx_t5_path, float16=(self.param_dtype == mx.float16))
|
||||
else:
|
||||
raise FileNotFoundError(f"T5 checkpoint not found: {t5_checkpoint_path} or {pth_path}")
|
||||
|
||||
t5_checkpoint_path = mlx_t5_path # Use the MLX version
|
||||
logging.info(f"Loading T5 text encoder... from {t5_checkpoint_path}")
|
||||
self.text_encoder = T5EncoderModel(
|
||||
text_len=config.text_len,
|
||||
checkpoint_path=t5_checkpoint_path,
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer))
|
||||
|
||||
# Initialize VAE - with automatic conversion
|
||||
vae_path = os.path.join(checkpoint_dir, config.vae_checkpoint)
|
||||
if not os.path.exists(vae_path):
|
||||
# Check for PyTorch VAE file to convert
|
||||
pth_vae_path = vae_path.replace('_mlx.safetensors', '.pth')
|
||||
if not os.path.exists(pth_vae_path):
|
||||
# Try alternative naming
|
||||
pth_vae_path = os.path.join(checkpoint_dir, 'Wan2.1_VAE.pth')
|
||||
|
||||
if os.path.exists(pth_vae_path):
|
||||
logging.info(f"Converting VAE PyTorch model to MLX: {pth_vae_path}")
|
||||
from .vae_model_io import convert_pytorch_to_mlx
|
||||
convert_pytorch_to_mlx(pth_vae_path, vae_path, float16=(self.param_dtype == mx.float16))
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE checkpoint not found: {vae_path} or {pth_vae_path}")
|
||||
|
||||
logging.info("Loading VAE...")
|
||||
self.vae = Wan2_1_VAE(vae_pth=vae_path)
|
||||
|
||||
# Load low and high noise models
|
||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||
|
||||
# Helper function to load model with automatic conversion
|
||||
def load_model_with_conversion(checkpoint_dir, subfolder, config, param_dtype):
|
||||
"""Load model with automatic PyTorch to MLX conversion if needed."""
|
||||
|
||||
# Look for existing MLX files
|
||||
mlx_single = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model_mlx.safetensors")
|
||||
mlx_pattern = os.path.join(checkpoint_dir, subfolder, "diffusion_mlx_model*.safetensors")
|
||||
mlx_files = glob.glob(mlx_pattern)
|
||||
|
||||
# If no MLX files, convert PyTorch files
|
||||
if not os.path.exists(mlx_single) and not mlx_files:
|
||||
pytorch_single = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model.safetensors")
|
||||
pytorch_pattern = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model-*.safetensors")
|
||||
pytorch_files = glob.glob(pytorch_pattern)
|
||||
|
||||
if os.path.exists(pytorch_single):
|
||||
logging.info(f"Converting PyTorch model to MLX: {pytorch_single}")
|
||||
convert_wan_2_2_safetensors_to_mlx(
|
||||
pytorch_single,
|
||||
mlx_single,
|
||||
float16=(param_dtype == mx.float16)
|
||||
)
|
||||
elif pytorch_files:
|
||||
logging.info(f"Converting {len(pytorch_files)} PyTorch files to MLX")
|
||||
convert_multiple_wan_2_2_safetensors_to_mlx(
|
||||
os.path.join(checkpoint_dir, subfolder),
|
||||
float16=(param_dtype == mx.float16)
|
||||
)
|
||||
mlx_files = glob.glob(mlx_pattern) # Update file list
|
||||
else:
|
||||
raise FileNotFoundError(f"No model files found in {os.path.join(checkpoint_dir, subfolder)}")
|
||||
|
||||
# Create model
|
||||
model = WanModel(
|
||||
model_type='t2v',
|
||||
patch_size=config.patch_size,
|
||||
text_len=config.text_len,
|
||||
in_dim=16,
|
||||
dim=config.dim,
|
||||
ffn_dim=config.ffn_dim,
|
||||
freq_dim=config.freq_dim,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=config.num_heads,
|
||||
num_layers=config.num_layers,
|
||||
window_size=getattr(config, 'window_size', (-1, -1)),
|
||||
qk_norm=getattr(config, 'qk_norm', True),
|
||||
cross_attn_norm=getattr(config, 'cross_attn_norm', True),
|
||||
eps=getattr(config, 'eps', 1e-6)
|
||||
)
|
||||
|
||||
# Load weights
|
||||
if os.path.exists(mlx_single):
|
||||
logging.info(f"Loading single MLX file: {mlx_single}")
|
||||
model = load_wan_2_2_from_safetensors(mlx_single, model, float16=(param_dtype == mx.float16))
|
||||
else:
|
||||
logging.info(f"Loading multiple MLX files from: {os.path.join(checkpoint_dir, subfolder)}")
|
||||
model = load_wan_2_2_from_safetensors(
|
||||
os.path.join(checkpoint_dir, subfolder),
|
||||
model,
|
||||
float16=(param_dtype == mx.float16)
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
# Load both models
|
||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||
logging.info("Loading low noise model")
|
||||
self.low_noise_model = load_model_with_conversion(
|
||||
checkpoint_dir,
|
||||
config.low_noise_checkpoint,
|
||||
self.config,
|
||||
self.param_dtype
|
||||
)
|
||||
self.low_noise_model = self._configure_model(self.low_noise_model, convert_model_dtype)
|
||||
|
||||
logging.info("Loading high noise model")
|
||||
self.high_noise_model = load_model_with_conversion(
|
||||
checkpoint_dir,
|
||||
config.high_noise_checkpoint,
|
||||
self.config,
|
||||
self.param_dtype
|
||||
)
|
||||
self.high_noise_model = self._configure_model(self.high_noise_model, convert_model_dtype)
|
||||
|
||||
self.sp_size = 1 # No sequence parallel in single device
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
|
||||
def _configure_model(self, model: nn.Module, convert_model_dtype: bool) -> nn.Module:
|
||||
"""
|
||||
Configures a model object for MLX.
|
||||
|
||||
Args:
|
||||
model (nn.Module):
|
||||
The model instance to configure.
|
||||
convert_model_dtype (`bool`):
|
||||
Convert model parameters dtype to 'config.param_dtype'.
|
||||
|
||||
Returns:
|
||||
nn.Module:
|
||||
The configured model.
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
if convert_model_dtype:
|
||||
# In MLX, we would need to manually convert parameters
|
||||
# This would be implemented in the actual model class
|
||||
pass
|
||||
|
||||
return model
|
||||
|
||||
def _prepare_model_for_timestep(self, t, boundary, offload_model):
|
||||
"""
|
||||
Prepares and returns the required model for the current timestep.
|
||||
"""
|
||||
if t.item() >= boundary:
|
||||
required_model_name = 'high_noise_model'
|
||||
offload_model_name = 'low_noise_model'
|
||||
else:
|
||||
required_model_name = 'low_noise_model'
|
||||
offload_model_name = 'high_noise_model'
|
||||
|
||||
# MLX doesn't need the CPU offloading logic, just return the right model
|
||||
return getattr(self, required_model_name)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_prompt: str,
|
||||
size: Tuple[int, int] = (1280, 720),
|
||||
frame_num: int = 81,
|
||||
shift: float = 5.0,
|
||||
sample_solver: str = 'unipc',
|
||||
sampling_steps: int = 50,
|
||||
guide_scale: Union[float, Tuple[float, float]] = 5.0,
|
||||
n_prompt: str = "",
|
||||
seed: int = -1,
|
||||
offload_model: bool = True
|
||||
) -> Optional[mx.array]:
|
||||
r"""
|
||||
Generates video frames from text prompt using diffusion process.
|
||||
|
||||
Args:
|
||||
input_prompt (`str`):
|
||||
Text prompt for content generation
|
||||
size (`tuple[int]`, *optional*, defaults to (1280,720)):
|
||||
Controls video resolution, (width,height).
|
||||
frame_num (`int`, *optional*, defaults to 81):
|
||||
How many frames to sample from a video. The number should be 4n+1
|
||||
shift (`float`, *optional*, defaults to 5.0):
|
||||
Noise schedule shift parameter. Affects temporal dynamics
|
||||
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
||||
Solver used to sample the video.
|
||||
sampling_steps (`int`, *optional*, defaults to 50):
|
||||
Number of diffusion sampling steps.
|
||||
guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
|
||||
Classifier-free guidance scale.
|
||||
n_prompt (`str`, *optional*, defaults to ""):
|
||||
Negative prompt for content exclusion.
|
||||
seed (`int`, *optional*, defaults to -1):
|
||||
Random seed for noise generation. If -1, use random seed.
|
||||
offload_model (`bool`, *optional*, defaults to True):
|
||||
Not used in MLX version (kept for compatibility)
|
||||
|
||||
Returns:
|
||||
mx.array:
|
||||
Generated video frames tensor. Dimensions: (C, N, H, W)
|
||||
"""
|
||||
# Preprocess
|
||||
guide_scale = (guide_scale, guide_scale) if isinstance(
|
||||
guide_scale, float) else guide_scale
|
||||
|
||||
F = frame_num
|
||||
target_shape = (
|
||||
self.vae.model.z_dim,
|
||||
(F - 1) // self.vae_stride[0] + 1,
|
||||
size[1] // self.vae_stride[1],
|
||||
size[0] // self.vae_stride[2]
|
||||
)
|
||||
|
||||
seq_len = math.ceil(
|
||||
(target_shape[2] * target_shape[3]) /
|
||||
(self.patch_size[1] * self.patch_size[2]) *
|
||||
target_shape[1] / self.sp_size
|
||||
) * self.sp_size
|
||||
|
||||
if n_prompt == "":
|
||||
n_prompt = self.sample_neg_prompt
|
||||
|
||||
# Set random seed
|
||||
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Encode text prompts
|
||||
context = self.text_encoder([input_prompt])
|
||||
context_null = self.text_encoder([n_prompt])
|
||||
|
||||
# Generate initial noise
|
||||
noise = [
|
||||
mx.random.normal(
|
||||
shape=target_shape,
|
||||
dtype=mx.float32
|
||||
)
|
||||
]
|
||||
|
||||
# Set boundary
|
||||
boundary = self.boundary * self.num_train_timesteps
|
||||
|
||||
# Initialize scheduler
|
||||
if sample_solver == 'unipc':
|
||||
sample_scheduler = FlowUniPCMultistepScheduler(
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False
|
||||
)
|
||||
sample_scheduler.set_timesteps(
|
||||
sampling_steps, shift=shift
|
||||
)
|
||||
timesteps = sample_scheduler.timesteps
|
||||
elif sample_solver == 'dpm++':
|
||||
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False
|
||||
)
|
||||
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
||||
timesteps, _ = retrieve_timesteps(
|
||||
sample_scheduler,
|
||||
sigmas=sampling_sigmas
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported solver.")
|
||||
|
||||
# Sample videos
|
||||
latents = noise
|
||||
|
||||
arg_c = {'context': context, 'seq_len': seq_len}
|
||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||
|
||||
mx.eval(latents)
|
||||
|
||||
# Denoising loop
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = latents
|
||||
timestep = mx.array([t])
|
||||
|
||||
# Select model based on timestep
|
||||
model = self._prepare_model_for_timestep(
|
||||
t, boundary, offload_model
|
||||
)
|
||||
sample_guide_scale = guide_scale[1] if t.item() >= boundary else guide_scale[0]
|
||||
|
||||
# Model predictions
|
||||
noise_pred_cond = model(
|
||||
latent_model_input, t=timestep, **arg_c
|
||||
)[0]
|
||||
mx.eval(noise_pred_cond) # Force evaluation
|
||||
|
||||
noise_pred_uncond = model(
|
||||
latent_model_input, t=timestep, **arg_null
|
||||
)[0]
|
||||
mx.eval(noise_pred_uncond) # Force evaluation
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred = noise_pred_uncond + sample_guide_scale * (
|
||||
noise_pred_cond - noise_pred_uncond
|
||||
)
|
||||
mx.eval(noise_pred) # Force evaluation
|
||||
|
||||
# Scheduler step
|
||||
temp_x0 = sample_scheduler.step(
|
||||
mx.expand_dims(noise_pred, axis=0),
|
||||
t,
|
||||
mx.expand_dims(latents[0], axis=0),
|
||||
return_dict=False
|
||||
)[0]
|
||||
latents = [mx.squeeze(temp_x0, axis=0)]
|
||||
|
||||
mx.eval(latents)
|
||||
|
||||
# Decode final latents
|
||||
x0 = latents
|
||||
videos = self.vae.decode(x0)
|
||||
|
||||
# Cleanup
|
||||
del noise, latents
|
||||
del sample_scheduler
|
||||
if offload_model:
|
||||
gc.collect()
|
||||
|
||||
return videos[0]
|
||||
Reference in New Issue
Block a user