mlx-examples/video/Wan2.1/wan/text2video.py
2025-07-28 17:07:26 -07:00

312 lines
12 KiB
Python

import glob
import gc
import logging
import math
import os
import random
import sys
from tqdm import tqdm
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .wan_model_io import load_wan_from_safetensors
class WanT2V:
def __init__(
self,
config,
checkpoint_dir,
):
r"""
Initializes the Wan text-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
"""
self.config = config
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = mx.float16 if config.param_dtype == 'float16' else mx.float32
# Initialize T5 text encoder - with automatic conversion
t5_checkpoint_path = os.path.join(checkpoint_dir, config.t5_checkpoint)
mlx_t5_path = t5_checkpoint_path.replace('.safetensors', '_mlx.safetensors')
if not os.path.exists(mlx_t5_path):
# Check if it's a .pth file that needs conversion
pth_path = t5_checkpoint_path.replace('.safetensors', '.pth')
if os.path.exists(pth_path):
logging.info(f"Converting T5 PyTorch model to safetensors: {pth_path}")
from .t5_torch_to_sf import convert_pickle_to_safetensors
convert_pickle_to_safetensors(pth_path, t5_checkpoint_path, load_method="weights_only")
# Convert torch safetensors to MLX safetensors
from .t5_model_io import convert_safetensors_to_mlx_weights
convert_safetensors_to_mlx_weights(t5_checkpoint_path, mlx_t5_path, float16=(self.param_dtype == mx.float16))
else:
raise FileNotFoundError(f"T5 checkpoint not found: {t5_checkpoint_path} or {pth_path}")
t5_checkpoint_path = mlx_t5_path # Use the MLX version
logging.info(f"Loading T5 text encoder... from {t5_checkpoint_path}")
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
checkpoint_path=t5_checkpoint_path,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer))
# Initialize VAE
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
# Initialize VAE - with automatic conversion
vae_path = os.path.join(checkpoint_dir, config.vae_checkpoint)
if not os.path.exists(vae_path):
# Check for PyTorch VAE file to convert
pth_vae_path = vae_path.replace('_mlx.safetensors', '.pth')
if not os.path.exists(pth_vae_path):
# Try alternative naming
pth_vae_path = os.path.join(checkpoint_dir, 'Wan2.1_VAE.pth')
if os.path.exists(pth_vae_path):
logging.info(f"Converting VAE PyTorch model to MLX: {pth_vae_path}")
from .vae_model_io import convert_pytorch_to_mlx
convert_pytorch_to_mlx(pth_vae_path, vae_path, float16=(self.param_dtype == mx.float16))
else:
raise FileNotFoundError(f"VAE checkpoint not found: {vae_path} or {pth_vae_path}")
logging.info("Loading VAE...")
self.vae = WanVAE(vae_pth=vae_path)
# Initialize WanModel
logging.info(f"Creating WanModel from {checkpoint_dir}")
# Create model with config parameters
self.model = WanModel(
model_type='t2v',
patch_size=config.patch_size,
text_len=config.text_len,
in_dim=16,
dim=config.dim,
ffn_dim=config.ffn_dim,
freq_dim=config.freq_dim,
text_dim=4096,
out_dim=16,
num_heads=config.num_heads,
num_layers=config.num_layers,
window_size=getattr(config, 'window_size', (-1, -1)),
qk_norm=getattr(config, 'qk_norm', True),
cross_attn_norm=getattr(config, 'cross_attn_norm', True),
eps=getattr(config, 'eps', 1e-6)
)
# In WanT2V.__init__, replace the model loading section with:
# Load pretrained weights - with automatic conversion
model_path = os.path.join(checkpoint_dir, "diffusion_pytorch_model_mlx.safetensors")
if not os.path.exists(model_path):
# Check for directory with multiple files (14B model)
pattern = os.path.join(checkpoint_dir, "diffusion_mlx_model*.safetensors")
mlx_files = glob.glob(pattern)
if not mlx_files:
# No MLX files found, look for PyTorch files to convert
pytorch_path = os.path.join(checkpoint_dir, "diffusion_pytorch_model.safetensors")
pytorch_pattern = os.path.join(checkpoint_dir, "diffusion_pytorch_model-*.safetensors")
pytorch_files = glob.glob(pytorch_pattern)
if os.path.exists(pytorch_path):
logging.info(f"Converting single PyTorch model to MLX: {pytorch_path}")
from .wan_model_io import convert_safetensors_to_mlx_weights
convert_safetensors_to_mlx_weights(
pytorch_path,
model_path,
float16=(self.param_dtype == mx.float16)
)
elif pytorch_files:
logging.info(f"Converting {len(pytorch_files)} PyTorch model files to MLX")
from .wan_model_io import convert_multiple_safetensors_to_mlx
convert_multiple_safetensors_to_mlx(
checkpoint_dir,
float16=(self.param_dtype == mx.float16)
)
else:
raise FileNotFoundError(f"No PyTorch model files found in {checkpoint_dir}")
# Load the model (now MLX format exists)
if os.path.exists(model_path):
# Single file (1.3B)
logging.info(f"Loading model weights from {model_path}")
self.model = load_wan_from_safetensors(model_path, self.model, float16=(self.param_dtype == mx.float16))
else:
# Multiple files (14B)
logging.info(f"Loading model weights from directory {checkpoint_dir}")
self.model = load_wan_from_safetensors(checkpoint_dir, self.model, float16=(self.param_dtype == mx.float16))
# Set model to eval mode
self.model.eval()
self.sp_size = 1 # No sequence parallelism in MLX version
self.sample_neg_prompt = config.sample_neg_prompt
def generate(self,
input_prompt,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tuple[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save memory
Returns:
mx.array:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames
- H: Frame height (from size)
- W: Frame width (from size)
"""
# Preprocess
F = frame_num
target_shape = (
self.vae.model.z_dim,
(F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2]
)
seq_len = math.ceil(
(target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size
) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# Set random seed
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
mx.random.seed(seed)
# Encode text prompts
logging.info("Encoding text prompts...")
context = self.text_encoder([input_prompt])
context_null = self.text_encoder([n_prompt])
# Generate initial noise
noise = [
mx.random.normal(
shape=target_shape,
dtype=mx.float32
)
]
# Initialize scheduler
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False
)
sample_scheduler.set_timesteps(
sampling_steps, shift=shift
)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False
)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
sigmas=sampling_sigmas
)
else:
raise NotImplementedError(f"Unsupported solver: {sample_solver}")
# Sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
logging.info(f"Generating video with {len(timesteps)} steps...")
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = mx.array([t])
# Model predictions
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c
)[0]
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null
)[0]
# Classifier-free guidance
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond
)
# Scheduler step
temp_x0 = sample_scheduler.step(
mx.expand_dims(noise_pred, 0),
t,
mx.expand_dims(latents[0], 0),
return_dict=False
)[0]
latents = [mx.squeeze(temp_x0, 0)]
mx.eval(latents)
x0 = latents
# Decode latents to video
logging.info("Decoding latents to video...")
videos = self.vae.decode(x0)
# Memory cleanup
del noise, latents, sample_scheduler
if offload_model:
mx.eval(videos) # Ensure computation is complete
gc.collect()
return videos[0]